Пример #1
0
 def testGetVersionFromCheckpoint(self):
     """Check that the model version is read back from a saved checkpoint.

     Saves ``self.params`` (with ``version`` set to 100) into a temporary
     directory, then asserts that ``CheckpointSaver.get_version_from_checkpoint``
     reports 100 for the ``version-100`` subdirectory of the returned
     checkpoint directory.
     """
     with tempfile.TemporaryDirectory() as tempdir:
         self.params.version = 100
         ckpt_dir = save_variables_to_checkpoint(tempdir, self.params)
         ckpt_version_dir = os.path.join(ckpt_dir, "version-100")
         model_version = CheckpointSaver.get_version_from_checkpoint(
             ckpt_version_dir)
         # assertEqual, not assertTrue: assertTrue(x, 100) only checks that x
         # is truthy (100 becomes the failure message), so it could never
         # catch a wrong version. NOTE(review): assumes the returned version
         # is an int, as the other call sites use it as a step count.
         self.assertEqual(model_version, 100)
Пример #2
0
    def _set_completed_steps_by_checkpoint(self, checkpoint_dir_for_init):
        """Set ``self._completed_steps`` from the version of a checkpoint.

        Args:
            checkpoint_dir_for_init: Checkpoint directory to initialise from.
                A falsy value (e.g. ``None`` or ``""``) makes this a no-op.

        Raises:
            ValueError: If ``CheckpointSaver.check_checkpoint_valid`` rejects
                the directory.
        """
        ckpt_dir = checkpoint_dir_for_init
        if not ckpt_dir:
            # Nothing to restore from.
            return

        is_valid = CheckpointSaver.check_checkpoint_valid(ckpt_dir)
        if not is_valid:
            message = "Invalid checkpoint directory {}".format(ckpt_dir)
            raise ValueError(message)

        version = CheckpointSaver.get_version_from_checkpoint(ckpt_dir)
        self._completed_steps = version
Пример #3
0
    def _set_completed_steps_by_checkpoint(self, checkpoint_dir_for_init):
        """Pass a checkpoint's model version to every MaxStepsStopping callback.

        Args:
            checkpoint_dir_for_init: Checkpoint directory to initialise from.
                A falsy value (e.g. ``None`` or ``""``) makes this a no-op.

        Raises:
            ValueError: If ``CheckpointSaver.check_checkpoint_valid`` rejects
                the directory.
        """
        ckpt_dir = checkpoint_dir_for_init
        if not ckpt_dir:
            # Nothing to restore from.
            return

        if not CheckpointSaver.check_checkpoint_valid(ckpt_dir):
            raise ValueError(
                "Invalid checkpoint directory {}".format(ckpt_dir))

        model_version = CheckpointSaver.get_version_from_checkpoint(ckpt_dir)
        # Lazily filtered, so each matching callback is updated in list order,
        # exactly as when the isinstance check is inside the loop.
        max_steps_callbacks = (
            cb for cb in self.callbacks_list.callbacks
            if isinstance(cb, MaxStepsStopping)
        )
        for callback in max_steps_callbacks:
            callback.set_completed_steps(model_version)