def testSaveRestore(self):
    """Tests that a running trial can be saved, restored and then stopped.

    The trial is started, one batch of results is fetched, a persistent
    checkpoint is saved and processed, and the trial is restored from it
    before being stopped.
    """
    trial = Trial("__fake")
    self.trial_executor.start_trial(trial)
    # NOTE(review): the hook is invoked with ``None`` here; confirm the
    # executor accepts that argument in this configuration.
    self.trial_executor.on_no_available_trials(None)
    self.assertEqual(Trial.RUNNING, trial.status)

    # The other tests in this file index the return value of
    # ``fetch_result`` with ``[-1]`` and treat ``last_result`` as a single
    # mapping, so store only the most recent result rather than the whole
    # returned sequence.
    trial.last_result = self.trial_executor.fetch_result(trial)[-1]

    self.trial_executor.save(trial, Checkpoint.PERSISTENT)
    self.process_trial_save(trial)
    self.trial_executor.restore(trial)

    self.trial_executor.stop_trial(trial)
    self.assertEqual(Trial.TERMINATED, trial.status)
def testBestTrialStr(self):
    """Assert that nested config values appear in the best-trial string.

    The check is made twice: once with the default columns and once with
    the nested key requested explicitly through ``parameter_columns``.
    """
    nested_config = {
        "nested": {"conf": "nested_value"},
        "toplevel": "toplevel_value",
    }
    stub_trial = Trial("", config=nested_config, stub=True)
    stub_trial.last_result = {"metric": 1, "config": nested_config}

    # First pass uses the default columns, second pass names the nested
    # column explicitly; both must surface the nested value.
    call_variants = (
        {},
        {"parameter_columns": ["nested/conf"]},
    )
    for extra_kwargs in call_variants:
        summary = best_trial_str(stub_trial, "metric", **extra_kwargs)
        self.assertIn("nested_value", summary)
def testAsyncSave(self):
    """Tests that the trial checkpoint is only set once the save is processed."""
    executor = self.trial_executor
    fake_trial = Trial("__fake")

    # Bring the trial up and record its most recent result.
    executor.start_trial(fake_trial)
    self.assertEqual(Trial.RUNNING, fake_trial.status)
    fetched = executor.fetch_result(fake_trial)
    fake_trial.last_result = fetched[-1]

    # Right after requesting the save, the checkpoint is tracked as
    # "saving to" while the trial's current checkpoint value is still None.
    pending_checkpoint = executor.save(fake_trial, Checkpoint.PERSISTENT)
    self.assertEqual(pending_checkpoint, fake_trial.saving_to)
    self.assertEqual(fake_trial.checkpoint.value, None)

    # After the save has been processed, it becomes the trial's checkpoint.
    self.process_trial_save(fake_trial)
    self.assertEqual(pending_checkpoint, fake_trial.checkpoint)

    executor.stop_trial(fake_trial)
    self.assertEqual(Trial.TERMINATED, fake_trial.status)
def _testPauseAndStart(self, result_buffer_length):
    """Tests that unpausing works for trials being processed.

    Args:
        result_buffer_length: Value exported as ``TUNE_RESULT_BUFFER_LENGTH``
            before a fresh executor is created. The training iteration
            expected after each fetch is a multiple of
            ``max(result_buffer_length, 1)``.
    """
    env_overrides = {
        "TUNE_RESULT_BUFFER_LENGTH": f"{result_buffer_length}",
        "TUNE_RESULT_BUFFER_MIN_TIME_S": "1",
    }
    # Snapshot the current values so the overrides do not leak into
    # whatever runs after this helper in the same process.
    previous_env = {key: os.environ.get(key) for key in env_overrides}
    os.environ.update(env_overrides)
    try:
        # Need a new trial executor so the ENV vars are parsed again
        self.trial_executor = RayTrialExecutor()

        base = max(result_buffer_length, 1)

        trial = Trial("__fake")
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        trial.last_result = self.trial_executor.fetch_result(trial)[-1]
        self.assertEqual(trial.last_result.get(TRAINING_ITERATION), base)

        self.trial_executor.pause_trial(trial)
        self.assertEqual(Trial.PAUSED, trial.status)

        # Restarting the paused trial must continue from where it left
        # off: the next fetch reports twice the base iteration count.
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        trial.last_result = self.trial_executor.fetch_result(trial)[-1]
        self.assertEqual(
            trial.last_result.get(TRAINING_ITERATION), base * 2)

        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)
    finally:
        # Put the environment back exactly as it was: remove keys that
        # were previously unset, restore the old value otherwise.
        for key, old_value in previous_env.items():
            if old_value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old_value