Example #1
    def testPauseResumeCheckpointCount(self):
        ray.init(num_cpus=2)
        tempdir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, tempdir)

        trial = Trial("__fake", keep_checkpoints_num=2)
        trial.init_logdir()
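        # Patch the delete hook so checkpoints evicted by keep_checkpoints_num
        # are actually removed from disk.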
        trial.checkpoint_manager.delete = lambda cp: shutil.rmtree(cp.value)

        def write_checkpoint(trial: Trial, index: int):
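            """Simulate a trainable saving a persistent checkpoint and register it with the trial."""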
            checkpoint_dir = TrainableUtil.make_checkpoint_dir(trial.logdir,
                                                               index=index)
            result = {"training_iteration": index}
            with open(os.path.join(checkpoint_dir, "cp.json"), "w") as f:
                json.dump(result, f)

            tune_cp = _TuneCheckpoint(_TuneCheckpoint.PERSISTENT,
                                      checkpoint_dir, result)
            trial.saving_to = tune_cp
            trial.on_checkpoint(tune_cp)

            return checkpoint_dir

        def get_checkpoint_dirs(trial: Trial):
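            """Return the checkpoint_* directory names currently on disk for the trial."""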
            return [
                d for d in os.listdir(trial.logdir)
                if d.startswith("checkpoint_")
            ]

        runner = TrialRunner(local_checkpoint_dir=tempdir)
        runner.add_trial(trial)

        # Write first checkpoint
        result = write_checkpoint(trial, 1)
        runner._on_saving_result(trial, result)

        # Expect 1 checkpoint
        cp_dirs = get_checkpoint_dirs(trial)
        self.assertEqual(len(cp_dirs), 1, msg=f"Checkpoint dirs: {cp_dirs}")

        # Write second checkpoint
        result = write_checkpoint(trial, 2)
        runner._on_saving_result(trial, result)

        # Expect 2 checkpoints
        cp_dirs = get_checkpoint_dirs(trial)
        self.assertEqual(len(cp_dirs), 2, msg=f"Checkpoint dirs: {cp_dirs}")

        # Write third checkpoint
        result = write_checkpoint(trial, 3)
        runner._on_saving_result(trial, result)

        # Expect 2 checkpoints because keep_checkpoints_num = 2
        cp_dirs = get_checkpoint_dirs(trial)
        self.assertEqual(len(cp_dirs), 2, msg=f"Checkpoint dirs: {cp_dirs}")

        # Re-instantiate trial runner and resume
        runner.checkpoint(force=True)
        runner = TrialRunner(local_checkpoint_dir=tempdir)
        runner.resume()

        trial = runner.get_trials()[0]
        trial.checkpoint_manager.delete = lambda cp: shutil.rmtree(cp.value)

        # Write fourth checkpoint
        result = write_checkpoint(trial, 4)
        runner._on_saving_result(trial, result)

        # Expect 2 checkpoints because keep_checkpoints_num = 2
        cp_dirs = get_checkpoint_dirs(trial)
        self.assertEqual(len(cp_dirs), 2, msg=f"Checkpoint dirs: {cp_dirs}")

        # Write fifth checkpoint
        result = write_checkpoint(trial, 5)
        runner._on_saving_result(trial, result)

        # Expect 2 checkpoints because keep_checkpoints_num = 2
        cp_dirs = get_checkpoint_dirs(trial)
        self.assertEqual(len(cp_dirs), 2, msg=f"Checkpoint dirs: {cp_dirs}")

        # Checkpoints before restore should be deleted
        self.assertIn("checkpoint_000004", cp_dirs)
        self.assertIn("checkpoint_000005", cp_dirs)

        self.assertNotIn("checkpoint_000002", cp_dirs)
        self.assertNotIn("checkpoint_000003", cp_dirs)