import os

import tensorflow as tf

from keras import backend
from keras.distribute import distributed_file_utils


# Method of keras.distribute.worker_training_state.WorkerTrainingState;
# CKPT_SAVED_EPOCH_UNUSED_VALUE and CKPT_SAVED_BATCH_UNUSED_VALUE are
# sentinel constants defined on the class.
def __init__(self, model, checkpoint_dir, save_freq="epoch"):
        self._model = model
        self._save_freq = save_freq
        # The epoch and batch at which the checkpoint is saved, used for
        # fault tolerance. int64 is used because GPU devices only have a
        # VarHandleOp kernel registered for the int64 dtype.
        self._ckpt_saved_epoch = tf.Variable(
            initial_value=tf.constant(
                self.CKPT_SAVED_EPOCH_UNUSED_VALUE, dtype=tf.int64
            ),
            name="ckpt_saved_epoch",
        )
        self._ckpt_saved_batch = tf.Variable(
            initial_value=tf.constant(
                self.CKPT_SAVED_BATCH_UNUSED_VALUE, dtype=tf.int64
            ),
            name="ckpt_saved_batch",
        )
        # Variable initialization.
        backend.set_value(
            self._ckpt_saved_epoch, self.CKPT_SAVED_EPOCH_UNUSED_VALUE
        )
        backend.set_value(
            self._ckpt_saved_batch, self.CKPT_SAVED_BATCH_UNUSED_VALUE
        )
        # _ckpt_saved_epoch and _ckpt_saved_batch get tracked and are included
        # in the checkpoint file when backing up.
        checkpoint = tf.train.Checkpoint(
            model=self._model,
            ckpt_saved_epoch=self._ckpt_saved_epoch,
            ckpt_saved_batch=self._ckpt_saved_batch,
            train_counter=self._model._train_counter,
        )

        # If this is single-worker training, checkpoint_dir is the same for
        # write_checkpoint_manager and read_checkpoint_manager.
        #
        # If this is multi-worker training and this worker should not save
        # checkpoints, we replace the write_checkpoint_manager's checkpoint_dir
        # with a temp filepath, so it writes to a file that will be removed at
        # the end of the back_up() call. This is necessary because the
        # SyncOnReadVariable needs to be synced across all the workers in order
        # to be read, and all workers need to perform `save()`. But all workers
        # should restore from the same checkpoint_dir as passed in
        # read_checkpoint_manager.
        self.read_checkpoint_manager = tf.train.CheckpointManager(
            checkpoint,
            directory=os.path.join(checkpoint_dir, "chief"),
            max_to_keep=1,
        )
        write_checkpoint_dir = distributed_file_utils.write_dirpath(
            checkpoint_dir, self._model.distribute_strategy
        )
        if self._model.distribute_strategy.extended.should_checkpoint:
            self.write_checkpoint_manager = self.read_checkpoint_manager
        else:
            self.write_checkpoint_manager = tf.train.CheckpointManager(
                checkpoint, directory=write_checkpoint_dir, max_to_keep=1
            )
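The comment above says the non-chief worker's checkpoint is removed at the end of the back_up() call. Below is a minimal sketch of how back_up() can tie the two managers together, assuming the imports above and a remove_temp_dirpath helper like the one exercised in the tests below; the names and exact call order are illustrative, not necessarily the verbatim implementation.

def back_up(self, epoch, batch=0):
    # Sketch: record progress, save on every worker, then drop any
    # worker-local temp directory.
    backend.set_value(self._ckpt_saved_epoch, epoch)
    backend.set_value(self._ckpt_saved_batch, batch)
    # Every worker must call save() so the SyncOnReadVariables can be
    # synced and read, even though non-chief workers write to a
    # throwaway directory.
    if self.write_checkpoint_manager.save():
        distributed_file_utils.remove_temp_dirpath(
            self.write_checkpoint_manager.directory,
            self._model.distribute_strategy,
        )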
def testChiefWriteDirAndFilePath(self):
    dirpath = self.get_temp_dir()
    filepath = os.path.join(dirpath, 'foo.bar')
    strategy = DistributedFileUtilsTest.MockedChiefStrategy()
    self.assertEqual(
        distributed_file_utils.write_filepath(filepath, strategy), filepath)
    self.assertEqual(
        distributed_file_utils.write_dirpath(dirpath, strategy), dirpath)
Example #3
def testMultipleRemoveDirToWritePathIsFine(self):
    temp_dir = self.get_temp_dir()
    strategy = DistributedFileUtilsTest.MockedWorkerStrategy()
    dir_to_write = distributed_file_utils.write_dirpath(temp_dir, strategy)
    file_to_write = os.path.join(dir_to_write, "tmp")
    self._write_dummy_file(file_to_write)
    # Removing the same temp dir repeatedly must be a harmless no-op.
    distributed_file_utils.remove_temp_dirpath(dir_to_write, strategy)
    distributed_file_utils.remove_temp_dirpath(dir_to_write, strategy)
    distributed_file_utils.remove_temp_dirpath(dir_to_write, strategy)
Example #4
def testWorkerWriteDirAndFilePath(self):
    dirpath = self.get_temp_dir()
    filepath = os.path.join(dirpath, 'foo.bar')
    strategy = DistributedFileUtilsTest.MockedWorkerStrategy()
    self.assertEqual(
        distributed_file_utils.write_filepath(filepath, strategy),
        os.path.join(dirpath, 'workertemp_3', 'foo.bar'))
    self.assertEqual(
        distributed_file_utils.write_dirpath(dirpath, strategy),
        os.path.join(dirpath, 'workertemp_3'))
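The chief and worker write-path tests pin down the contract: the chief writes in place, while a worker with task id 3 is redirected to a workertemp_3 subdirectory. Below is a minimal sketch of write_dirpath and write_filepath consistent with those assertions; it reuses the should_checkpoint flag seen in the __init__ snippet above, and _task_id is an assumed accessor implied by the mocked worker, not a documented API.

import os


def write_dirpath(dirpath, strategy):
    # Chief (or single-worker) training writes to the requested directory.
    if strategy.extended.should_checkpoint:
        return dirpath
    # Non-chief workers are redirected to a per-task temp subdirectory;
    # the mocks above imply a task id of 3.
    return os.path.join(
        dirpath, 'workertemp_' + str(strategy.extended._task_id))


def write_filepath(filepath, strategy):
    # Redirect only the directory component; the file name is unchanged.
    return os.path.join(
        write_dirpath(os.path.dirname(filepath), strategy),
        os.path.basename(filepath))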
Example #5
def testWorkerDoesRemoveDirPath(self):
    temp_dir = self.get_temp_dir()
    strategy = DistributedFileUtilsTest.MockedWorkerStrategy()
    dir_to_write = distributed_file_utils.write_dirpath(temp_dir, strategy)
    file_to_write = os.path.join(dir_to_write, "tmp")
    self.assertFalse(os.path.exists(file_to_write))
    self._write_dummy_file(file_to_write)
    self.assertTrue(os.path.exists(file_to_write))
    # Passing the original (base) directory removes the worker's temp
    # subdirectory and everything inside it.
    distributed_file_utils.remove_temp_dirpath(temp_dir, strategy)
    self.assertFalse(os.path.exists(file_to_write))
    self.assertFalse(os.path.exists(os.path.dirname(file_to_write)))
Example #6
def testChiefDoesNotRemoveDirAndFilePath(self):
    temp_dir = self.get_temp_dir()
    strategy = DistributedFileUtilsTest.MockedChiefStrategy()
    dir_to_write = distributed_file_utils.write_dirpath(temp_dir, strategy)
    file_to_write = os.path.join(dir_to_write, 'tmp')
    self.assertFalse(os.path.exists(file_to_write))
    self._write_dummy_file(file_to_write)
    self.assertTrue(os.path.exists(file_to_write))
    # The chief wrote in place, so its files must survive cleanup.
    distributed_file_utils.remove_temp_dir_with_filepath(
        file_to_write, strategy)
    self.assertTrue(os.path.exists(file_to_write))
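Taken together, the removal tests imply three properties: a non-chief worker deletes its workertemp_<task_id> subdirectory, repeated deletion is tolerated, and the chief's files are left untouched. The sketch below is consistent with that behavior; it is illustrative only, and the real helpers may differ in how they inspect the strategy or delete files.

import os

import tensorflow as tf


def remove_temp_dirpath(dirpath, strategy):
    # The chief never wrote to a temp directory, so there is nothing
    # to remove.
    if strategy.extended.should_checkpoint:
        return
    # Accept either the base directory or the temp directory itself.
    suffix = 'workertemp_' + str(strategy.extended._task_id)
    temp_dir = dirpath if dirpath.endswith(suffix) else os.path.join(
        dirpath, suffix)
    # Tolerating a missing directory makes repeated calls harmless no-ops.
    if tf.io.gfile.exists(temp_dir):
        tf.io.gfile.rmtree(temp_dir)


def remove_temp_dir_with_filepath(filepath, strategy):
    # Remove the temp directory that would contain the given file.
    remove_temp_dirpath(os.path.dirname(filepath), strategy)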