Example #1
0
    def testTrialSaveRestore(self):
        """Creates different trials to test runner.checkpoint/restore.

        Three trials are driven through a first runner: one that ends up
        TERMINATED, one that ends up in ERROR, and one that is left RUNNING.
        A second runner is then resumed from the same checkpoint directory;
        it must report the same status for the two finished trials, while the
        trial that was still running comes back as PENDING.
        """
        ray.init(num_cpus=3)
        tmpdir = tempfile.mkdtemp()
        # Registered up front so the directory is removed even when one of
        # the assertions below fails.
        self.addCleanup(shutil.rmtree, tmpdir)

        runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
        trials = [
            Trial("__fake",
                  trial_id="trial_terminate",
                  stopping_criterion={"training_iteration": 1},
                  checkpoint_freq=1)
        ]
        runner.add_trial(trials[0])
        runner.step()  # Start trial
        runner.step()  # Process result, dispatch save
        runner.step()  # Process save
        self.assertEqual(trials[0].status, Trial.TERMINATED)

        trials += [
            Trial("__fake",
                  trial_id="trial_fail",
                  stopping_criterion={"training_iteration": 3},
                  checkpoint_freq=1,
                  config={"mock_error": True})
        ]
        runner.add_trial(trials[1])
        runner.step()  # Start trial
        runner.step()  # Process result, dispatch save
        runner.step()  # Process save
        runner.step()  # Error
        self.assertEqual(trials[1].status, Trial.ERROR)

        trials += [
            Trial("__fake",
                  trial_id="trial_succ",
                  stopping_criterion={"training_iteration": 2},
                  checkpoint_freq=1)
        ]
        runner.add_trial(trials[2])
        runner.step()  # Start trial
        self.assertEqual(len(runner.trial_executor.get_checkpoints()), 3)
        self.assertEqual(trials[2].status, Trial.RUNNING)

        # Resume a fresh runner from what the first runner wrote to tmpdir.
        runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
        for tid in ["trial_terminate", "trial_fail"]:
            original_trial = runner.get_trial(tid)
            restored_trial = runner2.get_trial(tid)
            self.assertEqual(original_trial.status, restored_trial.status)

        restored_trial = runner2.get_trial("trial_succ")
        self.assertEqual(Trial.PENDING, restored_trial.status)

        runner2.step()  # Start trial
        runner2.step()  # Process result, dispatch save
        runner2.step()  # Process save
        runner2.step()  # Process result, dispatch save
        runner2.step()  # Process save
        # One further step is expected to raise.
        self.assertRaises(TuneError, runner2.step)
Example #2
0
    def testTrialNoCheckpointSave(self):
        """Check that non-checkpointing trials *are* saved.

        Runs one trial without checkpointing and one with
        ``checkpoint_at_end=True`` to termination, then steps a third trial
        until it has reported at least once. A runner resumed from the same
        directory must list all three trials: the first two as TERMINATED and
        the third as PENDING with ``has_reported_at_least_once`` still set.
        """
        # NOTE(review): this variable is not restored within this method --
        # confirm the fixture's setUp/tearDown resets it.
        os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"

        ray.init(num_cpus=3)

        runner = TrialRunner(local_checkpoint_dir=self.tmpdir,
                             checkpoint_period=0)
        runner.add_trial(
            Trial(
                "__fake",
                trial_id="non_checkpoint",
                stopping_criterion={"training_iteration": 2},
            ))

        while not all(t.status == Trial.TERMINATED
                      for t in runner.get_trials()):
            runner.step()

        runner.add_trial(
            Trial(
                "__fake",
                trial_id="checkpoint",
                checkpoint_at_end=True,
                stopping_criterion={"training_iteration": 2},
            ))

        while not all(t.status == Trial.TERMINATED
                      for t in runner.get_trials()):
            runner.step()

        runner.add_trial(
            Trial(
                "__fake",
                trial_id="pending",
                stopping_criterion={"training_iteration": 2},
            ))

        # Index 2 is the third trial added above ("pending").
        old_trials = runner.get_trials()
        while not old_trials[2].has_reported_at_least_once:
            runner.step()

        runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=self.tmpdir)
        new_trials = runner2.get_trials()
        self.assertEqual(len(new_trials), 3)
        # assertEqual (rather than assertTrue on a comparison) so a failure
        # reports the actual status.
        self.assertEqual(
            runner2.get_trial("non_checkpoint").status, Trial.TERMINATED)
        self.assertEqual(
            runner2.get_trial("checkpoint").status, Trial.TERMINATED)
        self.assertEqual(runner2.get_trial("pending").status, Trial.PENDING)
        self.assertTrue(
            runner2.get_trial("pending").has_reported_at_least_once)
        runner2.step()
Example #3
0
    def testTrialNoSave(self):
        """Check what a resumed runner restores for non-checkpointing trials.

        Runs one trial without checkpointing and one with
        ``checkpoint_at_end=True`` to termination, then steps a third trial
        twice. A runner resumed from the same directory must list all three
        trials: the first two as TERMINATED and the third as PENDING with an
        empty ``last_result``.
        """
        ray.init(num_cpus=3)
        tmpdir = tempfile.mkdtemp()
        # Registered up front so the directory is removed even when one of
        # the assertions below fails.
        self.addCleanup(shutil.rmtree, tmpdir)

        runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
        runner.add_trial(
            Trial(
                "__fake",
                trial_id="non_checkpoint",
                stopping_criterion={"training_iteration": 2}))

        while not all(t.status == Trial.TERMINATED
                      for t in runner.get_trials()):
            runner.step()

        runner.add_trial(
            Trial(
                "__fake",
                trial_id="checkpoint",
                checkpoint_at_end=True,
                stopping_criterion={"training_iteration": 2}))

        while not all(t.status == Trial.TERMINATED
                      for t in runner.get_trials()):
            runner.step()

        runner.add_trial(
            Trial(
                "__fake",
                trial_id="pending",
                stopping_criterion={"training_iteration": 2}))

        runner.step()
        runner.step()

        runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
        new_trials = runner2.get_trials()
        self.assertEqual(len(new_trials), 3)
        # assertEqual / assertFalse so a failure reports the actual value.
        self.assertEqual(
            runner2.get_trial("non_checkpoint").status, Trial.TERMINATED)
        self.assertEqual(
            runner2.get_trial("checkpoint").status, Trial.TERMINATED)
        self.assertEqual(runner2.get_trial("pending").status, Trial.PENDING)
        self.assertFalse(runner2.get_trial("pending").last_result)
        runner2.step()
Example #4
0
    def testTrialSaveRestore(self):
        """Exercise runner checkpointing and resuming across trial states.

        A first runner drives one trial to TERMINATED and one to ERROR, then
        starts a third. A second runner resumed from ``self.tmpdir`` must
        mirror the first two statuses, see the third as PENDING, and run it
        through to TERMINATED.
        """
        ray.init(num_cpus=3)

        def run_until_finished(trial_runner):
            # Step repeatedly until trial_runner.is_finished() is true.
            while not trial_runner.is_finished():
                trial_runner.step()

        runner = TrialRunner(
            local_checkpoint_dir=self.tmpdir, checkpoint_period=0)

        # First trial: stops after a single training iteration.
        terminating_trial = Trial(
            "__fake",
            trial_id="trial_terminate",
            stopping_criterion={"training_iteration": 1},
            checkpoint_freq=1,
        )
        runner.add_trial(terminating_trial)
        run_until_finished(runner)
        self.assertEqual(terminating_trial.status, Trial.TERMINATED)

        # Second trial: configured with mock_error and expected to end in
        # ERROR.
        failing_trial = Trial(
            "__fake",
            trial_id="trial_fail",
            stopping_criterion={"training_iteration": 3},
            checkpoint_freq=1,
            config={"mock_error": True},
        )
        runner.add_trial(failing_trial)
        run_until_finished(runner)
        self.assertEqual(failing_trial.status, Trial.ERROR)

        # Third trial: only started, so it is still running at resume time.
        running_trial = Trial(
            "__fake",
            trial_id="trial_succ",
            stopping_criterion={"training_iteration": 2},
            checkpoint_freq=1,
        )
        runner.add_trial(running_trial)
        runner.step()
        checkpoints = runner.trial_executor.get_checkpoints()
        self.assertEqual(len(checkpoints), 3)
        self.assertEqual(running_trial.status, Trial.RUNNING)

        runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=self.tmpdir)
        for finished_id in ("trial_terminate", "trial_fail"):
            status_before = runner.get_trial(finished_id).status
            status_after = runner2.get_trial(finished_id).status
            self.assertEqual(status_before, status_after)

        restored_trial = runner2.get_trial("trial_succ")
        self.assertEqual(Trial.PENDING, restored_trial.status)

        # The resumed runner takes the restored trial through to completion.
        run_until_finished(runner2)
        self.assertEqual(restored_trial.status, Trial.TERMINATED)