def testTrialSaveRestore(self):
    """Creates different trials to test runner.checkpoint/restore.

    Adds, one after another, three trials:

    * a trial that stops after one iteration and ends TERMINATED;
    * a trial configured with ``mock_error`` that ends in ERROR;
    * a trial that is left RUNNING.

    A second ``TrialRunner`` created with ``resume="LOCAL"`` over the same
    directory must restore the first two statuses and bring the running
    trial back as PENDING. Once that trial has been stepped through its
    remaining work, a further ``step()`` must raise ``TuneError``.
    """
    ray.init(num_cpus=3)
    tmpdir = tempfile.mkdtemp()
    # Register removal right away so the directory is cleaned up even when
    # an assertion below fails; a trailing rmtree would be skipped then.
    self.addCleanup(shutil.rmtree, tmpdir)
    runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
    trials = [
        Trial(
            "__fake",
            trial_id="trial_terminate",
            stopping_criterion={"training_iteration": 1},
            checkpoint_freq=1)
    ]
    runner.add_trial(trials[0])
    runner.step()  # Start trial
    runner.step()  # Process result, dispatch save
    runner.step()  # Process save
    self.assertEqual(trials[0].status, Trial.TERMINATED)

    trials += [
        Trial(
            "__fake",
            trial_id="trial_fail",
            stopping_criterion={"training_iteration": 3},
            checkpoint_freq=1,
            config={"mock_error": True})
    ]
    runner.add_trial(trials[1])
    runner.step()  # Start trial
    runner.step()  # Process result, dispatch save
    runner.step()  # Process save
    runner.step()  # Error
    self.assertEqual(trials[1].status, Trial.ERROR)

    trials += [
        Trial(
            "__fake",
            trial_id="trial_succ",
            stopping_criterion={"training_iteration": 2},
            checkpoint_freq=1)
    ]
    runner.add_trial(trials[2])
    runner.step()  # Start trial
    self.assertEqual(len(runner.trial_executor.get_checkpoints()), 3)
    self.assertEqual(trials[2].status, Trial.RUNNING)

    runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
    for tid in ["trial_terminate", "trial_fail"]:
        original_trial = runner.get_trial(tid)
        restored_trial = runner2.get_trial(tid)
        self.assertEqual(original_trial.status, restored_trial.status)

    restored_trial = runner2.get_trial("trial_succ")
    self.assertEqual(Trial.PENDING, restored_trial.status)

    runner2.step()  # Start trial
    runner2.step()  # Process result, dispatch save
    runner2.step()  # Process save
    runner2.step()  # Process result, dispatch save
    runner2.step()  # Process save
    # With no work left, stepping the resumed runner again is expected to fail.
    self.assertRaises(TuneError, runner2.step)
def testTrialNoCheckpointSave(self):
    """Check that non-checkpointing trials *are* saved.

    Runs three trials through a runner:

    * ``non_checkpoint`` runs to TERMINATED without checkpointing;
    * ``checkpoint`` runs to TERMINATED with ``checkpoint_at_end``;
    * ``pending`` is stepped only until it has reported at least once.

    A runner resumed with ``resume="LOCAL"`` must see all three. The first
    two must be TERMINATED. ``pending`` must be PENDING and still have
    ``has_reported_at_least_once`` set.
    """
    from unittest import mock

    # Scope the env override to this test: patch.dict restores os.environ
    # on cleanup, so later tests do not inherit the setting.
    env_patcher = mock.patch.dict(
        os.environ, {"TUNE_MAX_PENDING_TRIALS_PG": "1"})
    env_patcher.start()
    self.addCleanup(env_patcher.stop)

    ray.init(num_cpus=3)
    runner = TrialRunner(
        local_checkpoint_dir=self.tmpdir, checkpoint_period=0)

    runner.add_trial(
        Trial(
            "__fake",
            trial_id="non_checkpoint",
            stopping_criterion={"training_iteration": 2},
        ))
    while not all(t.status == Trial.TERMINATED
                  for t in runner.get_trials()):
        runner.step()

    runner.add_trial(
        Trial(
            "__fake",
            trial_id="checkpoint",
            checkpoint_at_end=True,
            stopping_criterion={"training_iteration": 2},
        ))
    while not all(t.status == Trial.TERMINATED
                  for t in runner.get_trials()):
        runner.step()

    runner.add_trial(
        Trial(
            "__fake",
            trial_id="pending",
            stopping_criterion={"training_iteration": 2},
        ))
    old_trials = runner.get_trials()
    # Stop as soon as the third trial has reported, so it is saved mid-run.
    while not old_trials[2].has_reported_at_least_once:
        runner.step()

    runner2 = TrialRunner(
        resume="LOCAL", local_checkpoint_dir=self.tmpdir)
    new_trials = runner2.get_trials()
    self.assertEqual(len(new_trials), 3)
    self.assertEqual(
        runner2.get_trial("non_checkpoint").status, Trial.TERMINATED)
    self.assertEqual(
        runner2.get_trial("checkpoint").status, Trial.TERMINATED)
    self.assertEqual(runner2.get_trial("pending").status, Trial.PENDING)
    self.assertTrue(
        runner2.get_trial("pending").has_reported_at_least_once)
    runner2.step()
def testTrialNoSave(self):
    """Check that trials without checkpointing are still restored on resume.

    Runs three trials:

    * ``non_checkpoint`` runs to TERMINATED without checkpointing;
    * ``checkpoint`` runs to TERMINATED with ``checkpoint_at_end``;
    * ``pending`` is stepped only twice.

    A runner resumed with ``resume="LOCAL"`` must see all three. The first
    two must be TERMINATED. ``pending`` must be PENDING with an empty
    ``last_result``.
    """
    ray.init(num_cpus=3)
    tmpdir = tempfile.mkdtemp()
    # Register removal right away so the directory is cleaned up even when
    # an assertion below fails; a trailing rmtree would be skipped then.
    self.addCleanup(shutil.rmtree, tmpdir)
    runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)

    runner.add_trial(
        Trial(
            "__fake",
            trial_id="non_checkpoint",
            stopping_criterion={"training_iteration": 2}))
    while not all(t.status == Trial.TERMINATED
                  for t in runner.get_trials()):
        runner.step()

    runner.add_trial(
        Trial(
            "__fake",
            trial_id="checkpoint",
            checkpoint_at_end=True,
            stopping_criterion={"training_iteration": 2}))
    while not all(t.status == Trial.TERMINATED
                  for t in runner.get_trials()):
        runner.step()

    runner.add_trial(
        Trial(
            "__fake",
            trial_id="pending",
            stopping_criterion={"training_iteration": 2}))
    # Only two steps: presumably not enough for "pending" to persist a
    # result (the assertFalse on last_result below relies on this).
    runner.step()
    runner.step()

    runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
    new_trials = runner2.get_trials()
    self.assertEqual(len(new_trials), 3)
    self.assertEqual(
        runner2.get_trial("non_checkpoint").status, Trial.TERMINATED)
    self.assertEqual(
        runner2.get_trial("checkpoint").status, Trial.TERMINATED)
    self.assertEqual(runner2.get_trial("pending").status, Trial.PENDING)
    self.assertFalse(runner2.get_trial("pending").last_result)
    runner2.step()
def testTrialSaveRestore(self):
    """Exercise TrialRunner checkpoint/restore across three trial outcomes.

    Trials are added one at a time to a runner created with
    ``checkpoint_period=0``:

    * ``trial_terminate`` stops after one iteration (TERMINATED);
    * ``trial_fail`` uses ``mock_error`` (ERROR);
    * ``trial_succ`` is started and left RUNNING.

    A second runner resumed with ``resume="LOCAL"`` from ``self.tmpdir``
    must match the first two statuses and bring ``trial_succ`` back as
    PENDING. It must then drive ``trial_succ`` to TERMINATED.
    """

    def step_until_finished(trial_runner):
        # Each step starts a trial, handles a result, dispatches/processes
        # a save, or records an error; repeat until the runner is done.
        while not trial_runner.is_finished():
            trial_runner.step()

    ray.init(num_cpus=3)
    runner = TrialRunner(
        local_checkpoint_dir=self.tmpdir, checkpoint_period=0)

    terminating = Trial(
        "__fake",
        trial_id="trial_terminate",
        stopping_criterion={"training_iteration": 1},
        checkpoint_freq=1,
    )
    runner.add_trial(terminating)
    step_until_finished(runner)
    self.assertEqual(terminating.status, Trial.TERMINATED)

    failing = Trial(
        "__fake",
        trial_id="trial_fail",
        stopping_criterion={"training_iteration": 3},
        checkpoint_freq=1,
        config={"mock_error": True},
    )
    runner.add_trial(failing)
    step_until_finished(runner)
    self.assertEqual(failing.status, Trial.ERROR)

    succeeding = Trial(
        "__fake",
        trial_id="trial_succ",
        stopping_criterion={"training_iteration": 2},
        checkpoint_freq=1,
    )
    runner.add_trial(succeeding)
    runner.step()  # Start trial
    checkpoint_count = len(runner.trial_executor.get_checkpoints())
    self.assertEqual(checkpoint_count, 3)
    self.assertEqual(succeeding.status, Trial.RUNNING)

    resumed = TrialRunner(
        resume="LOCAL", local_checkpoint_dir=self.tmpdir)
    for tid in ("trial_terminate", "trial_fail"):
        before = runner.get_trial(tid)
        after = resumed.get_trial(tid)
        self.assertEqual(before.status, after.status)

    revived = resumed.get_trial("trial_succ")
    self.assertEqual(Trial.PENDING, revived.status)
    step_until_finished(resumed)
    self.assertEqual(revived.status, Trial.TERMINATED)