def testExperimentTagTruncation(self):
    """Check that a trial's logdir stays short for very long config entries.

    Every trial generated from the experiment is started, its ``logdir``
    is asserted to be at most 200 characters long, and it is stopped again.
    """
    ray.init()

    def train(config, reporter):
        # Minimal trainable: report one timestep and return.
        reporter(timesteps_total=1)

    executor = RayTrialExecutor()
    register_trainable("f1", train)

    # Deliberately long entries: two 50-character keys, one of which
    # resolves to a 160-character string.
    long_config = {
        "a" * 50: lambda spec: 5.0 / 7,
        "b" * 50: lambda spec: "long" * 40,
    }
    experiments = {"foo": {"run": "f1", "config": long_config}}

    for exp_name, exp_spec in experiments.items():
        generator = BasicVariantGenerator()
        generator.add_configurations({exp_name: exp_spec})
        for generated in generator.next_trials():
            executor.start_trial(generated)
            self.assertLessEqual(len(generated.logdir), 200)
            executor.stop_trial(generated)
def testExperimentTagTruncation(self):
    """Check that the logdir basename stays short for long config entries.

    Trials are pulled from the generator one at a time; each is started,
    the basename of its ``logdir`` is asserted to be at most 200
    characters long, and it is stopped again.
    """
    ray.init(num_cpus=2)

    def train(config, reporter):
        # Minimal trainable: report one timestep and return.
        reporter(timesteps_total=1)

    executor = RayTrialExecutor()
    register_trainable("f1", train)

    # Deliberately long entries: two 50-character keys, one of which
    # samples a 160-character string.
    long_config = {
        "a" * 50: tune.sample_from(lambda spec: 5.0 / 7),
        "b" * 50: tune.sample_from(lambda spec: "long" * 40),
    }
    experiments = {"foo": {"run": "f1", "config": long_config}}

    for exp_name, exp_spec in experiments.items():
        generator = BasicVariantGenerator()
        generator.add_configurations({exp_name: exp_spec})
        while not generator.is_finished():
            generated = generator.next_trial()
            if not generated:
                # The generator produced nothing even though it is not
                # finished; stop draining it.
                break
            executor.start_trial(generated)
            dir_name = os.path.basename(generated.logdir)
            self.assertLessEqual(len(dir_name), 200)
            executor.stop_trial(generated)
def testExperimentTagTruncation(self):
    """Check logdir length and status transitions for an AdaptDL trainable.

    Each generated trial must be RUNNING after ``start_trial``, have a
    logdir basename of at most 200 characters, and be TERMINATED after
    ``stop_trial``.
    """
    ray.init(num_cpus=2)
    trainable_cls = AdaptDLTrainableCreator(_train_simple, num_workers=1)
    executor = RayTrialExecutor()

    # Deliberately long entries: two 50-character keys, one of which
    # samples a 160-character string.
    long_config = {
        "a" * 50: tune.sample_from(lambda spec: 5.0 / 7),
        "b" * 50: tune.sample_from(lambda spec: "long" * 40),
    }
    experiments = {
        "foo": {"run": trainable_cls.__name__, "config": long_config},
    }

    for exp_name, exp_spec in experiments.items():
        generator = BasicVariantGenerator()
        generator.add_configurations({exp_name: exp_spec})
        while not generator.is_finished():
            generated = generator.next_trial()
            if not generated:
                # Nothing was produced although the generator is not
                # finished; stop draining it.
                break
            executor.start_trial(generated)
            assert generated.status == Trial.RUNNING
            dir_name = os.path.basename(generated.logdir)
            assert len(dir_name) <= 200
            executor.stop_trial(generated)
            assert generated.status == Trial.TERMINATED
def testTrialErrorOnStart(self):
    """Check the error message when starting a trial whose trainable is None.

    ``None`` is registered under the name "asdf" and a trial using that
    name is started. If ``start_trial`` raises, the message must contain
    "a class".

    NOTE(review): when ``start_trial`` does not raise at all, no assertion
    runs and the test passes -- confirm that this leniency is intended.
    """
    ray.init()
    executor = RayTrialExecutor()
    _global_registry.register(TRAINABLE_CLASS, "asdf", None)
    broken_trial = Trial("asdf", resources=Resources(1, 0))
    try:
        executor.start_trial(broken_trial)
    except Exception as err:
        message = str(err)
        self.assertIn("a class", message)
def testTrialStatus(self):
    """Walk a "__fake" trial through its status transitions.

    Expected sequence: PENDING on creation, RUNNING after start,
    TERMINATED after a plain stop, ERROR after a stop with ``error=True``.
    """
    ray.init()
    fake_trial = Trial("__fake")
    executor = RayTrialExecutor()
    self.assertEqual(fake_trial.status, Trial.PENDING)

    executor.start_trial(fake_trial)
    self.assertEqual(fake_trial.status, Trial.RUNNING)

    executor.stop_trial(fake_trial)
    self.assertEqual(fake_trial.status, Trial.TERMINATED)

    # Stopping again with error=True moves the trial to ERROR.
    executor.stop_trial(fake_trial, error=True)
    self.assertEqual(fake_trial.status, Trial.ERROR)
def testTrialStatus(self):
    """Walk an AdaptDL trial through its status transitions.

    Expected sequence: PENDING on creation, RUNNING after start,
    TERMINATED after a plain stop, ERROR after a stop with ``error=True``.
    """
    ray.init(num_cpus=2)
    trainable_cls = AdaptDLTrainableCreator(_train_simple, num_workers=2)
    adaptdl_trial = AdaptDLTrial(trainable_cls.__name__, trial_id="0")
    executor = RayTrialExecutor()
    assert adaptdl_trial.status == Trial.PENDING

    executor.start_trial(adaptdl_trial)
    assert adaptdl_trial.status == Trial.RUNNING

    executor.stop_trial(adaptdl_trial)
    assert adaptdl_trial.status == Trial.TERMINATED

    # Stopping again with error=True moves the trial to ERROR.
    executor.stop_trial(adaptdl_trial, error=True)
    assert adaptdl_trial.status == Trial.ERROR
def _exploit_trial(self, trial_executor: RayTrialExecutor, trial: Trial,
                   trial_to_clone: Trial):
    """Move ``trial`` onto an explored copy of ``trial_to_clone``'s state.

    A new config is derived from ``trial_to_clone.config`` via ``explore``
    and ``trial`` is restarted from the donor's last checkpoint, either by
    an in-place reset or, when the reset is not successful, by stopping and
    starting it again. Step and exploration counters are updated
    afterwards. When ``self._log_config`` is set, the change is also passed
    to ``self._log_config_on_step``.

    Nothing happens (apart from a log line) if the donor has no
    checkpoint yet.
    """
    recipient_state = self._trials_states_dict[trial]
    donor_state = self._trials_states_dict[trial_to_clone]

    # Guard: without a donor checkpoint there is nothing to transfer.
    if not donor_state.last_checkpoint:
        logger.info(
            "[pbt]: no checkpoint for trial. Skip exploit for Trial {}".
            format(trial))
        return

    new_config = explore(
        trial_to_clone.config,
        self._hyperparam_mutations,
        self._hyperparam_mutate_probability,
        self._explore_func,
    )
    logger.info(
        "[exploit] transferring weights from trial {} (score {}) -> {} (score {})"
        .format(trial_to_clone, donor_state.last_score, trial,
                recipient_state.last_score))

    if self._log_config:
        self._log_config_on_step(recipient_state, donor_state, trial,
                                 trial_to_clone, new_config)

    new_tag = make_experiment_tag(recipient_state.orig_tag, new_config,
                                  self._hyperparam_mutations)
    was_reset = trial_executor.reset_trial(trial, new_config, new_tag)
    if not was_reset:
        # In-place reset was not possible: restart the trial with the new
        # config and tag, loading the donor checkpoint on start.
        trial_executor.stop_trial(trial, stop_logger=False)
        trial.config = new_config
        trial.experiment_tag = new_tag
        trial_executor.start_trial(
            trial, Checkpoint.from_object(donor_state.last_checkpoint))
    else:
        trial_executor.restore(
            trial, Checkpoint.from_object(donor_state.last_checkpoint))

    # TODO: move to Exploiter
    donor_state.num_steps = 0
    recipient_state.num_steps = 0
    donor_state.num_explorations = 0
    recipient_state.num_explorations += 1
    self._num_explorations += 1
    # Transfer over the last perturbation time as well
    recipient_state.last_perturbation_time = (
        donor_state.last_perturbation_time)
class RayTrialExecutorTest(unittest.TestCase):
    """Start/stop and save/restore checks for ``RayTrialExecutor``."""

    def setUp(self):
        """Create the executor under test, then initialise Ray."""
        self.trial_executor = RayTrialExecutor(queue_trials=False)
        ray.init()

    def tearDown(self):
        """Shut Ray down and restore the trainable registry."""
        ray.shutdown()
        _register_all()  # re-register the evicted objects

    def _get_trials(self):
        """Return the trials of a grid search over ``bar`` and ``foo``."""
        grid_spec = {
            "run": "PPO",
            "config": {
                "bar": {"grid_search": [True, False]},
                "foo": {"grid_search": [1, 2, 3]},
            },
        }
        generated = self.generate_trials(grid_spec, "grid_search")
        return list(generated)

    def testStartStop(self):
        """A started trial shows up as the only running trial."""
        fake_trial = Trial("__fake")
        self.trial_executor.start_trial(fake_trial)
        currently_running = self.trial_executor.get_running_trials()
        self.assertEqual(1, len(currently_running))
        self.trial_executor.stop_trial(fake_trial)

    def testSaveRestore(self):
        """A running trial can be saved to disk, restored and stopped."""
        fake_trial = Trial("__fake")
        self.trial_executor.start_trial(fake_trial)
        self.assertEqual(Trial.RUNNING, fake_trial.status)

        self.trial_executor.save(fake_trial, Checkpoint.DISK)
        self.trial_executor.restore(fake_trial)

        self.trial_executor.stop_trial(fake_trial)
        self.assertEqual(Trial.TERMINATED, fake_trial.status)

    def generate_trials(self, spec, name):
        """Return the trials ``BasicVariantGenerator`` yields for ``spec``."""
        variant_generator = BasicVariantGenerator({name: spec})
        return variant_generator.next_trials()
class RayTrialExecutorTest(unittest.TestCase):
    """Event-driven tests for ``RayTrialExecutor``.

    Trials are advanced by polling ``get_next_executor_event`` through the
    ``_simulate_*`` helpers, which stand in for the trial runner.
    """

    def setUp(self):
        """Create the executor under test, then initialise Ray."""
        self.trial_executor = RayTrialExecutor()
        ray.init(num_cpus=2, ignore_reinit_error=True)
        _register_all()  # Needed for flaky tests

    def tearDown(self):
        """Shut Ray down and restore the trainable registry."""
        ray.shutdown()
        _register_all()  # re-register the evicted objects

    def _set_env(self, key, value):
        """Set ``os.environ[key]`` and restore its prior state after the test.

        Without the restore, TUNE_* settings written by one test would
        stay in the process environment and be picked up by executors
        created in later tests.
        """
        previous = os.environ.get(key)

        def restore():
            if previous is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous

        self.addCleanup(restore)
        os.environ[key] = value

    def _simulate_starting_trial(self, trial):
        """Wait for the PG_READY event, start ``trial``, expect RUNNING."""
        future_result = self.trial_executor.get_next_executor_event(
            live_trials={trial}, next_trial_exists=True)
        assert future_result.type == ExecutorEventType.PG_READY
        self.assertTrue(self.trial_executor.start_trial(trial))
        self.assertEqual(Trial.RUNNING, trial.status)

    def _simulate_getting_result(self, trial):
        """Poll until a TRAINING_RESULT event arrives and apply it to ``trial``."""
        while True:
            future_result = self.trial_executor.get_next_executor_event(
                live_trials={trial}, next_trial_exists=False)
            if future_result.type == ExecutorEventType.TRAINING_RESULT:
                break
        # The event may carry a list of results or a single result.
        if isinstance(future_result.result, list):
            for r in future_result.result:
                trial.update_last_result(r)
        else:
            trial.update_last_result(future_result.result)

    def _simulate_saving(self, trial):
        """Trigger a persistent save and process the SAVING_RESULT event."""
        checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
        self.assertEqual(checkpoint, trial.saving_to)
        # The checkpoint value is only filled in once the save result is
        # processed below.
        self.assertEqual(trial.checkpoint.value, None)
        future_result = self.trial_executor.get_next_executor_event(
            live_trials={trial}, next_trial_exists=False)
        assert future_result.type == ExecutorEventType.SAVING_RESULT
        self.process_trial_save(trial, future_result.result)
        self.assertEqual(checkpoint, trial.checkpoint)

    def testStartStop(self):
        """A trial can be started and stopped."""
        trial = Trial("__fake")
        self._simulate_starting_trial(trial)
        self.trial_executor.stop_trial(trial)

    def testAsyncSave(self):
        """Tests that saved checkpoint value not immediately set."""
        trial = Trial("__fake")
        self._simulate_starting_trial(trial)
        self._simulate_getting_result(trial)
        self._simulate_saving(trial)
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testSaveRestore(self):
        """A saved trial can be restored and then stopped."""
        trial = Trial("__fake")
        self._simulate_starting_trial(trial)
        self._simulate_getting_result(trial)
        self._simulate_saving(trial)
        self.trial_executor.restore(trial)
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testPauseResume(self):
        """Tests that pausing works for trials in flight."""
        trial = Trial("__fake")
        self._simulate_starting_trial(trial)
        self.trial_executor.pause_trial(trial)
        self.assertEqual(Trial.PAUSED, trial.status)
        self._simulate_starting_trial(trial)
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testSavePauseResumeErrorRestore(self):
        """Tests that pause checkpoint does not replace restore checkpoint."""
        trial = Trial("__fake")
        self._simulate_starting_trial(trial)

        self._simulate_getting_result(trial)

        # Save
        self._simulate_saving(trial)

        # Train
        self.trial_executor.continue_training(trial)
        self._simulate_getting_result(trial)

        # Pause
        self.trial_executor.pause_trial(trial)
        self.assertEqual(Trial.PAUSED, trial.status)
        self.assertEqual(trial.checkpoint.storage, Checkpoint.MEMORY)

        # Resume
        self._simulate_starting_trial(trial)

        # Error
        trial.set_status(Trial.ERROR)

        # Restore
        self.trial_executor.restore(trial)

        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testStartFailure(self):
        """Starting a trial whose registered trainable is None gives ERROR."""
        _global_registry.register(TRAINABLE_CLASS, "asdf", None)
        trial = Trial("asdf", resources=Resources(1, 0))
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.ERROR, trial.status)

    def testPauseResume2(self):
        """Tests that pausing works for trials being processed."""
        trial = Trial("__fake")
        self._simulate_starting_trial(trial)
        self._simulate_getting_result(trial)
        self.trial_executor.pause_trial(trial)
        self.assertEqual(Trial.PAUSED, trial.status)
        self._simulate_starting_trial(trial)
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def _testPauseAndStart(self, result_buffer_length):
        """Tests that unpausing works for trials being processed."""
        # Set through _set_env so the buffer settings do not leak into
        # tests that run afterwards.
        self._set_env("TUNE_RESULT_BUFFER_LENGTH", f"{result_buffer_length}")
        self._set_env("TUNE_RESULT_BUFFER_MIN_TIME_S", "1")

        # Need a new trial executor so the ENV vars are parsed again
        self.trial_executor = RayTrialExecutor()

        # A buffer length of 0 is expected to behave like a length of 1.
        base = max(result_buffer_length, 1)

        trial = Trial("__fake")
        self._simulate_starting_trial(trial)

        self._simulate_getting_result(trial)
        self.assertEqual(trial.last_result.get(TRAINING_ITERATION), base)

        self.trial_executor.pause_trial(trial)
        self.assertEqual(Trial.PAUSED, trial.status)

        self._simulate_starting_trial(trial)

        self._simulate_getting_result(trial)
        self.assertEqual(trial.last_result.get(TRAINING_ITERATION), base * 2)
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testPauseAndStartNoBuffer(self):
        self._testPauseAndStart(0)

    def testPauseAndStartTrivialBuffer(self):
        self._testPauseAndStart(1)

    def testPauseAndStartActualBuffer(self):
        self._testPauseAndStart(8)

    def testNoResetTrial(self):
        """Tests that reset handles NotImplemented properly."""
        trial = Trial("__fake")
        self._simulate_starting_trial(trial)
        exists = self.trial_executor.reset_trial(trial, {}, "modified_mock")
        self.assertEqual(exists, False)
        self.assertEqual(Trial.RUNNING, trial.status)

    def testResetTrial(self):
        """Tests that reset works as expected."""

        class B(Trainable):
            def step(self):
                return dict(timesteps_this_iter=1, done=True)

            def reset_config(self, config):
                self.config = config
                return True

        trials = self.generate_trials(
            {
                "run": B,
                "config": {
                    "foo": 0
                },
            },
            "grid_search",
        )
        trial = trials[0]
        self._simulate_starting_trial(trial)
        exists = self.trial_executor.reset_trial(trial, {"hi": 1},
                                                 "modified_mock")
        self.assertEqual(exists, True)
        self.assertEqual(trial.config.get("hi"), 1)
        self.assertEqual(trial.experiment_tag, "modified_mock")
        self.assertEqual(Trial.RUNNING, trial.status)

    def testTrialCleanup(self):
        """Compare graceful and forced trial cleanup durations."""

        class B(Trainable):
            def step(self):
                print("Step start")
                time.sleep(4)
                print("Step done")
                return dict(my_metric=1, timesteps_this_iter=1, done=True)

            def reset_config(self, config):
                self.config = config
                return True

            def cleanup(self):
                print("Cleanup start")
                time.sleep(4)
                print("Cleanup done")

        # First check if the trials terminate gracefully by default
        trials = self.generate_trials(
            {
                "run": B,
                "config": {
                    "foo": 0
                },
            },
            "grid_search",
        )
        trial = trials[0]
        self._simulate_starting_trial(trial)
        time.sleep(1)
        print("Stop trial")
        self.trial_executor.stop_trial(trial)
        print("Start trial cleanup")
        start = time.time()
        self.trial_executor.cleanup([trial])
        # 4 - 1 + 4.
        self.assertGreaterEqual(time.time() - start, 6)

        # Check forceful termination. It should run for much less than the
        # sleep periods in the Trainable
        trials = self.generate_trials(
            {
                "run": B,
                "config": {
                    "foo": 0
                },
            },
            "grid_search",
        )
        trial = trials[0]
        # _set_env registers a cleanup that restores the pre-test state of
        # the variable once the test is over.
        self._set_env("TUNE_FORCE_TRIAL_CLEANUP_S", "1")
        self.trial_executor = RayTrialExecutor()
        # As before, the variable is switched back to "0" right after the
        # executor has been created.
        os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "0"
        self._simulate_starting_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        time.sleep(1)
        print("Stop trial")
        self.trial_executor.stop_trial(trial)
        print("Start trial cleanup")
        start = time.time()
        self.trial_executor.cleanup([trial])
        # less than 1 with some margin.
        self.assertLess(time.time() - start, 2.0)

        # also check if auto-filled metrics were returned
        self.assertIn(PID, trial.last_result)
        self.assertIn(TRIAL_ID, trial.last_result)
        self.assertNotIn("my_metric", trial.last_result)

    @staticmethod
    def generate_trials(spec, name):
        """Return the list of trials generated for experiment ``name``."""
        suggester = BasicVariantGenerator()
        suggester.add_configurations({name: spec})
        trials = []
        while not suggester.is_finished():
            trial = suggester.next_trial()
            if trial:
                trials.append(trial)
            else:
                break
        return trials

    def process_trial_save(self, trial, checkpoint_value):
        """Simulates trial runner save."""
        checkpoint = trial.saving_to
        checkpoint.value = checkpoint_value
        trial.on_checkpoint(checkpoint)
class RayTrialExecutorTest(unittest.TestCase):
    """Tests for ``RayTrialExecutor`` driven through ``fetch_result``."""

    def setUp(self):
        """Configure TUNE_* settings, create the executor, initialise Ray."""
        # Every variable goes through _set_env so that its previous state
        # is restored when the test finishes.
        # Wait up to five seconds for placement groups when starting a trial
        self._set_env("TUNE_PLACEMENT_GROUP_WAIT_S", "5")
        # Block for results even when placement groups are pending
        self._set_env("TUNE_TRIAL_STARTUP_GRACE_PERIOD", "0")
        self._set_env("TUNE_TRIAL_RESULT_WAIT_TIME_S", "99999")

        self.trial_executor = RayTrialExecutor(queue_trials=False)
        ray.init(num_cpus=2, ignore_reinit_error=True)
        _register_all()  # Needed for flaky tests

    def tearDown(self):
        """Shut Ray down and restore the trainable registry."""
        ray.shutdown()
        _register_all()  # re-register the evicted objects

    def _set_env(self, key, value):
        """Set ``os.environ[key]`` and restore its prior state after the test.

        Without the restore, TUNE_* settings written by one test would
        stay in the process environment and affect later tests.
        """
        previous = os.environ.get(key)

        def restore():
            if previous is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous

        self.addCleanup(restore)
        os.environ[key] = value

    def testStartStop(self):
        """A started trial shows up as the only running trial."""
        trial = Trial("__fake")
        self.trial_executor.start_trial(trial)
        running = self.trial_executor.get_running_trials()
        self.assertEqual(1, len(running))
        self.trial_executor.stop_trial(trial)

    def testAsyncSave(self):
        """Tests that saved checkpoint value not immediately set."""
        trial = Trial("__fake")
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        trial.last_result = self.trial_executor.fetch_result(trial)[-1]
        checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
        self.assertEqual(checkpoint, trial.saving_to)
        self.assertEqual(trial.checkpoint.value, None)
        self.process_trial_save(trial)
        self.assertEqual(checkpoint, trial.checkpoint)
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testSaveRestore(self):
        """A saved trial can be restored and then stopped."""
        trial = Trial("__fake")
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        trial.last_result = self.trial_executor.fetch_result(trial)[-1]
        self.trial_executor.save(trial, Checkpoint.PERSISTENT)
        self.process_trial_save(trial)
        self.trial_executor.restore(trial)
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testPauseResume(self):
        """Tests that pausing works for trials in flight."""
        trial = Trial("__fake")
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        self.trial_executor.pause_trial(trial)
        self.assertEqual(Trial.PAUSED, trial.status)
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testSavePauseResumeErrorRestore(self):
        """Tests that pause checkpoint does not replace restore checkpoint."""
        trial = Trial("__fake")
        self.trial_executor.start_trial(trial)
        trial.last_result = self.trial_executor.fetch_result(trial)[-1]
        # Save
        checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
        self.assertEqual(Trial.RUNNING, trial.status)
        self.assertEqual(checkpoint.storage, Checkpoint.PERSISTENT)
        # Process save result (simulates trial runner)
        self.process_trial_save(trial)
        # Train
        self.trial_executor.continue_training(trial)
        trial.last_result = self.trial_executor.fetch_result(trial)[-1]
        # Pause
        self.trial_executor.pause_trial(trial)
        self.assertEqual(Trial.PAUSED, trial.status)
        self.assertEqual(trial.checkpoint.storage, Checkpoint.MEMORY)
        # Resume
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        # Error
        trial.set_status(Trial.ERROR)
        # Restore
        self.trial_executor.restore(trial)
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testStartFailure(self):
        """Starting a trial whose registered trainable is None gives ERROR."""
        _global_registry.register(TRAINABLE_CLASS, "asdf", None)
        trial = Trial("asdf", resources=Resources(1, 0))
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.ERROR, trial.status)

    def testPauseResume2(self):
        """Tests that pausing works for trials being processed."""
        trial = Trial("__fake")
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        self.trial_executor.fetch_result(trial)
        checkpoint = self.trial_executor.pause_trial(trial)
        self.assertEqual(Trial.PAUSED, trial.status)
        self.trial_executor.start_trial(trial, checkpoint)
        self.assertEqual(Trial.RUNNING, trial.status)
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def _testPauseUnpause(self, result_buffer_length):
        """Tests that unpausing works for trials being processed."""
        # Set through _set_env so the buffer settings do not leak into
        # tests that run afterwards.
        self._set_env("TUNE_RESULT_BUFFER_LENGTH", f"{result_buffer_length}")
        self._set_env("TUNE_RESULT_BUFFER_MIN_TIME_S", "1")

        # Need a new trial executor so the ENV vars are parsed again
        self.trial_executor = RayTrialExecutor(queue_trials=False)

        # A buffer length of 0 is expected to behave like a length of 1.
        base = max(result_buffer_length, 1)

        trial = Trial("__fake")
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        trial.last_result = self.trial_executor.fetch_result(trial)[-1]
        self.assertEqual(trial.last_result.get(TRAINING_ITERATION), base)
        self.trial_executor.pause_trial(trial)
        self.assertEqual(Trial.PAUSED, trial.status)
        self.trial_executor.unpause_trial(trial)
        self.assertEqual(Trial.PENDING, trial.status)
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        trial.last_result = self.trial_executor.fetch_result(trial)[-1]
        self.assertEqual(trial.last_result.get(TRAINING_ITERATION), base * 2)
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testPauseUnpauseNoBuffer(self):
        self._testPauseUnpause(0)

    def testPauseUnpauseTrivialBuffer(self):
        self._testPauseUnpause(1)

    def testPauseUnpauseActualBuffer(self):
        self._testPauseUnpause(8)

    def testNoResetTrial(self):
        """Tests that reset handles NotImplemented properly."""
        trial = Trial("__fake")
        self.trial_executor.start_trial(trial)
        exists = self.trial_executor.reset_trial(trial, {}, "modified_mock")
        self.assertEqual(exists, False)
        self.assertEqual(Trial.RUNNING, trial.status)

    def testResetTrial(self):
        """Tests that reset works as expected."""

        class B(Trainable):
            def step(self):
                return dict(timesteps_this_iter=1, done=True)

            def reset_config(self, config):
                self.config = config
                return True

        trials = self.generate_trials({
            "run": B,
            "config": {
                "foo": 0
            },
        }, "grid_search")
        trial = trials[0]
        self.trial_executor.start_trial(trial)
        exists = self.trial_executor.reset_trial(trial, {"hi": 1},
                                                 "modified_mock")
        self.assertEqual(exists, True)
        self.assertEqual(trial.config.get("hi"), 1)
        self.assertEqual(trial.experiment_tag, "modified_mock")
        self.assertEqual(Trial.RUNNING, trial.status)

    def testForceTrialCleanup(self):
        """Compare graceful and forced trial cleanup durations."""

        class B(Trainable):
            def step(self):
                print("Step start")
                time.sleep(10)
                print("Step done")
                return dict(my_metric=1, timesteps_this_iter=1, done=True)

            def reset_config(self, config):
                self.config = config
                return True

            def cleanup(self):
                print("Cleanup start")
                time.sleep(10)
                print("Cleanup done")

        # First check if the trials terminate gracefully by default
        trials = self.generate_trials({
            "run": B,
            "config": {
                "foo": 0
            },
        }, "grid_search")
        trial = trials[0]
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        time.sleep(5)
        print("Stop trial")
        self.trial_executor.stop_trial(trial)
        print("Start trial cleanup")
        start = time.time()
        self.trial_executor.cleanup([trial])
        self.assertGreaterEqual(time.time() - start, 12.0)

        # Check forceful termination. It should run for much less than the
        # sleep periods in the Trainable
        trials = self.generate_trials({
            "run": B,
            "config": {
                "foo": 0
            },
        }, "grid_search")
        trial = trials[0]
        # _set_env registers a cleanup that restores the pre-test state of
        # the variable once the test is over.
        self._set_env("TUNE_FORCE_TRIAL_CLEANUP_S", "1")
        self.trial_executor = RayTrialExecutor(queue_trials=False)
        # As before, the variable is switched back to "0" right after the
        # executor has been created.
        os.environ["TUNE_FORCE_TRIAL_CLEANUP_S"] = "0"
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)
        time.sleep(5)
        print("Stop trial")
        self.trial_executor.stop_trial(trial)
        print("Start trial cleanup")
        start = time.time()
        self.trial_executor.cleanup([trial])
        self.assertLess(time.time() - start, 5.0)

        # also check if auto-filled metrics were returned
        self.assertIn(PID, trial.last_result)
        self.assertIn(TRIAL_ID, trial.last_result)
        self.assertNotIn("my_metric", trial.last_result)

    @staticmethod
    def generate_trials(spec, name):
        """Return the list of trials generated for experiment ``name``."""
        suggester = BasicVariantGenerator()
        suggester.add_configurations({name: spec})
        trials = []
        while not suggester.is_finished():
            trial = suggester.next_trial()
            if trial:
                trials.append(trial)
            else:
                break
        return trials

    def process_trial_save(self, trial):
        """Simulates trial runner save."""
        checkpoint = trial.saving_to
        checkpoint_value = self.trial_executor.fetch_result(trial)[-1]
        checkpoint.value = checkpoint_value
        trial.on_checkpoint(checkpoint)
class RayExecutorQueueTest(unittest.TestCase):
    """Resource-availability tests for an executor with queueing enabled."""

    def setUp(self):
        """Start a cluster whose head node has one CPU and build the executor."""
        self.cluster = Cluster(
            initialize_head=True,
            connect=True,
            head_node_args={
                "num_cpus": 1,
                "_system_config": {
                    "num_heartbeats_timeout": 10
                }
            })
        self.trial_executor = RayTrialExecutor(
            queue_trials=True, refresh_period=0)
        # Pytest doesn't play nicely with imports
        _register_all()

    def tearDown(self):
        """Shut down Ray and the cluster, then restore the registry."""
        ray.shutdown()
        self.cluster.shutdown()
        _register_all()  # re-register the evicted objects

    def testQueueTrial(self):
        """Tests that trials are reported as schedulable when queueing.

        The head node is created with a single CPU and no explicit GPU
        count. After a CPU-only trial has been started, the executor is
        still expected to report resources for a GPU-only trial.
        """

        def create_trial(cpu, gpu=0):
            return Trial("__fake", resources=Resources(cpu=cpu, gpu=gpu))

        cpu_only = create_trial(1, 0)
        self.assertTrue(self.trial_executor.has_resources_for_trial(cpu_only))
        self.trial_executor.start_trial(cpu_only)

        gpu_only = create_trial(0, 1)
        self.assertTrue(self.trial_executor.has_resources_for_trial(gpu_only))

    def testHeadBlocking(self):
        """Check resource reporting before and after a node is added."""
        # Once resource requests are deprecated, remove this test
        env_key = "TUNE_PLACEMENT_GROUP_AUTO_DISABLED"
        previous = os.environ.get(env_key)

        def restore_env():
            # Put the variable back to its pre-test state so the setting
            # does not leak into tests that run afterwards.
            if previous is None:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = previous

        self.addCleanup(restore_env)
        os.environ[env_key] = "1"

        def create_trial(cpu, gpu=0):
            return Trial("__fake", resources=Resources(cpu=cpu, gpu=gpu))

        gpu_trial = create_trial(1, 1)
        self.assertTrue(self.trial_executor.has_resources_for_trial(gpu_trial))
        self.trial_executor.start_trial(gpu_trial)

        # TODO(rliaw): This behavior is probably undesirable, but right now
        #  trials with different resource requirements is not often used.
        cpu_only_trial = create_trial(1, 0)
        self.assertFalse(
            self.trial_executor.has_resources_for_trial(cpu_only_trial))

        self.cluster.add_node(num_cpus=1, num_gpus=1)
        self.cluster.wait_for_nodes()

        self.assertTrue(
            self.trial_executor.has_resources_for_trial(cpu_only_trial))
        self.trial_executor.start_trial(cpu_only_trial)

        cpu_only_trial2 = create_trial(1, 0)
        self.assertTrue(
            self.trial_executor.has_resources_for_trial(cpu_only_trial2))
        self.trial_executor.start_trial(cpu_only_trial2)

        cpu_only_trial3 = create_trial(1, 0)
        self.assertFalse(
            self.trial_executor.has_resources_for_trial(cpu_only_trial3))
class RayTrialExecutorTest(unittest.TestCase):
    """Lifecycle tests (start, save, pause, reset) for ``RayTrialExecutor``."""

    def setUp(self):
        """Create the executor under test, then initialise Ray."""
        self.trial_executor = RayTrialExecutor(queue_trials=False)
        ray.init()
        _register_all()  # Needed for flaky tests

    def tearDown(self):
        """Shut Ray down and restore the trainable registry."""
        ray.shutdown()
        _register_all()  # re-register the evicted objects

    def _start_and_expect_running(self, trial):
        """Start ``trial`` and assert that it is RUNNING afterwards."""
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)

    def _stop_and_expect_terminated(self, trial):
        """Stop ``trial`` and assert that it is TERMINATED afterwards."""
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testStartStop(self):
        """A started trial shows up as the only running trial."""
        fake_trial = Trial("__fake")
        self.trial_executor.start_trial(fake_trial)
        currently_running = self.trial_executor.get_running_trials()
        self.assertEqual(1, len(currently_running))
        self.trial_executor.stop_trial(fake_trial)

    def testSaveRestore(self):
        """A running trial can be saved to disk, restored and stopped."""
        fake_trial = Trial("__fake")
        self._start_and_expect_running(fake_trial)
        self.trial_executor.save(fake_trial, Checkpoint.DISK)
        self.trial_executor.restore(fake_trial)
        self._stop_and_expect_terminated(fake_trial)

    def testPauseResume(self):
        """Pausing works for a trial whose result was never fetched."""
        fake_trial = Trial("__fake")
        self._start_and_expect_running(fake_trial)
        self.trial_executor.pause_trial(fake_trial)
        self.assertEqual(Trial.PAUSED, fake_trial.status)
        self._start_and_expect_running(fake_trial)
        self._stop_and_expect_terminated(fake_trial)

    def testStartFailure(self):
        """Starting a trial whose registered trainable is None gives ERROR."""
        _global_registry.register(TRAINABLE_CLASS, "asdf", None)
        broken_trial = Trial("asdf", resources=Resources(1, 0))
        self.trial_executor.start_trial(broken_trial)
        self.assertEqual(Trial.ERROR, broken_trial.status)

    def testPauseResume2(self):
        """Pausing works for a trial whose result has been fetched."""
        fake_trial = Trial("__fake")
        self._start_and_expect_running(fake_trial)
        self.trial_executor.fetch_result(fake_trial)
        self.trial_executor.pause_trial(fake_trial)
        self.assertEqual(Trial.PAUSED, fake_trial.status)
        self._start_and_expect_running(fake_trial)
        self._stop_and_expect_terminated(fake_trial)

    def testNoResetTrial(self):
        """``reset_trial`` returns False for "__fake"; the trial keeps running."""
        fake_trial = Trial("__fake")
        self.trial_executor.start_trial(fake_trial)
        was_reset = self.trial_executor.reset_trial(fake_trial, {},
                                                    "modified_mock")
        self.assertEqual(was_reset, False)
        self.assertEqual(Trial.RUNNING, fake_trial.status)

    def testResetTrial(self):
        """``reset_trial`` applies the new config and tag when supported."""

        class B(Trainable):
            def _train(self):
                return dict(timesteps_this_iter=1, done=True)

            def reset_config(self, config):
                self.config = config
                return True

        spec = {"run": B, "config": {"foo": 0}}
        generated = self.generate_trials(spec, "grid_search")
        first_trial = generated[0]
        self.trial_executor.start_trial(first_trial)
        was_reset = self.trial_executor.reset_trial(first_trial, {"hi": 1},
                                                    "modified_mock")
        self.assertEqual(was_reset, True)
        self.assertEqual(first_trial.config.get("hi"), 1)
        self.assertEqual(first_trial.experiment_tag, "modified_mock")
        self.assertEqual(Trial.RUNNING, first_trial.status)

    def generate_trials(self, spec, name):
        """Return the trials ``BasicVariantGenerator`` yields for ``spec``."""
        variant_generator = BasicVariantGenerator()
        variant_generator.add_configurations({name: spec})
        return variant_generator.next_trials()
class RayTrialExecutorTest(unittest.TestCase):
    """Lifecycle tests for ``RayTrialExecutor`` including async saves."""

    def setUp(self):
        """Create the executor under test, then initialise Ray."""
        self.trial_executor = RayTrialExecutor(queue_trials=False)
        ray.init()
        _register_all()  # Needed for flaky tests

    def tearDown(self):
        """Shut Ray down and restore the trainable registry."""
        ray.shutdown()
        _register_all()  # re-register the evicted objects

    def _start_and_expect_running(self, trial, *start_args):
        """Start ``trial`` (forwarding extra args) and expect RUNNING."""
        self.trial_executor.start_trial(trial, *start_args)
        self.assertEqual(Trial.RUNNING, trial.status)

    def _stop_and_expect_terminated(self, trial):
        """Stop ``trial`` and assert that it is TERMINATED afterwards."""
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testStartStop(self):
        """A started trial shows up as the only running trial."""
        fake_trial = Trial("__fake")
        self.trial_executor.start_trial(fake_trial)
        currently_running = self.trial_executor.get_running_trials()
        self.assertEqual(1, len(currently_running))
        self.trial_executor.stop_trial(fake_trial)

    def testAsyncSave(self):
        """The checkpoint value is only set once the save is processed."""
        fake_trial = Trial("__fake")
        self._start_and_expect_running(fake_trial)
        fake_trial.last_result = self.trial_executor.fetch_result(fake_trial)
        pending = self.trial_executor.save(fake_trial, Checkpoint.PERSISTENT)
        self.assertEqual(pending, fake_trial.saving_to)
        self.assertEqual(fake_trial.checkpoint.value, None)
        self.process_trial_save(fake_trial)
        self.assertEqual(pending, fake_trial.checkpoint)
        self._stop_and_expect_terminated(fake_trial)

    def testSaveRestore(self):
        """A saved trial can be restored and then stopped."""
        fake_trial = Trial("__fake")
        self._start_and_expect_running(fake_trial)
        fake_trial.last_result = self.trial_executor.fetch_result(fake_trial)
        self.trial_executor.save(fake_trial, Checkpoint.PERSISTENT)
        self.process_trial_save(fake_trial)
        self.trial_executor.restore(fake_trial)
        self._stop_and_expect_terminated(fake_trial)

    def testPauseResume(self):
        """Pausing works for a trial whose result was never fetched."""
        fake_trial = Trial("__fake")
        self._start_and_expect_running(fake_trial)
        self.trial_executor.pause_trial(fake_trial)
        self.assertEqual(Trial.PAUSED, fake_trial.status)
        self._start_and_expect_running(fake_trial)
        self._stop_and_expect_terminated(fake_trial)

    def testSavePauseResumeErrorRestore(self):
        """Pausing after a persistent save leaves a MEMORY checkpoint and
        the trial can still be restored after an error."""
        fake_trial = Trial("__fake")
        self.trial_executor.start_trial(fake_trial)
        fake_trial.last_result = self.trial_executor.fetch_result(fake_trial)

        # Step 1: persistent save.
        saved = self.trial_executor.save(fake_trial, Checkpoint.PERSISTENT)
        self.assertEqual(Trial.RUNNING, fake_trial.status)
        self.assertEqual(saved.storage, Checkpoint.PERSISTENT)

        # Step 2: handle the save result the way the trial runner would.
        self.process_trial_save(fake_trial)

        # Step 3: train for another result.
        self.trial_executor.continue_training(fake_trial)
        fake_trial.last_result = self.trial_executor.fetch_result(fake_trial)

        # Step 4: pause.
        self.trial_executor.pause_trial(fake_trial)
        self.assertEqual(Trial.PAUSED, fake_trial.status)
        self.assertEqual(fake_trial.checkpoint.storage, Checkpoint.MEMORY)

        # Step 5: resume.
        self._start_and_expect_running(fake_trial)

        # Step 6: mark the trial as errored.
        fake_trial.set_status(Trial.ERROR)

        # Step 7: restore and shut down.
        self.trial_executor.restore(fake_trial)
        self._stop_and_expect_terminated(fake_trial)

    def testStartFailure(self):
        """Starting a trial whose registered trainable is None gives ERROR."""
        _global_registry.register(TRAINABLE_CLASS, "asdf", None)
        broken_trial = Trial("asdf", resources=Resources(1, 0))
        self.trial_executor.start_trial(broken_trial)
        self.assertEqual(Trial.ERROR, broken_trial.status)

    def testPauseResume2(self):
        """Pausing works for a trial whose result has been fetched."""
        fake_trial = Trial("__fake")
        self._start_and_expect_running(fake_trial)
        self.trial_executor.fetch_result(fake_trial)
        pause_checkpoint = self.trial_executor.pause_trial(fake_trial)
        self.assertEqual(Trial.PAUSED, fake_trial.status)
        # Resume from whatever pause_trial returned.
        self._start_and_expect_running(fake_trial, pause_checkpoint)
        self._stop_and_expect_terminated(fake_trial)

    def testPauseUnpause(self):
        """Unpausing returns the trial to PENDING and training continues."""
        fake_trial = Trial("__fake")
        self._start_and_expect_running(fake_trial)
        fake_trial.last_result = self.trial_executor.fetch_result(fake_trial)
        self.assertEqual(fake_trial.last_result.get(TRAINING_ITERATION), 1)

        self.trial_executor.pause_trial(fake_trial)
        self.assertEqual(Trial.PAUSED, fake_trial.status)
        self.trial_executor.unpause_trial(fake_trial)
        self.assertEqual(Trial.PENDING, fake_trial.status)

        self._start_and_expect_running(fake_trial)
        fake_trial.last_result = self.trial_executor.fetch_result(fake_trial)
        # The iteration counter carries on from before the pause.
        self.assertEqual(fake_trial.last_result.get(TRAINING_ITERATION), 2)
        self._stop_and_expect_terminated(fake_trial)

    def testNoResetTrial(self):
        """``reset_trial`` returns False for "__fake"; the trial keeps running."""
        fake_trial = Trial("__fake")
        self.trial_executor.start_trial(fake_trial)
        was_reset = self.trial_executor.reset_trial(fake_trial, {},
                                                    "modified_mock")
        self.assertEqual(was_reset, False)
        self.assertEqual(Trial.RUNNING, fake_trial.status)

    def testResetTrial(self):
        """``reset_trial`` applies the new config and tag when supported."""

        class B(Trainable):
            def step(self):
                return dict(timesteps_this_iter=1, done=True)

            def reset_config(self, config):
                self.config = config
                return True

        spec = {"run": B, "config": {"foo": 0}}
        generated = self.generate_trials(spec, "grid_search")
        first_trial = generated[0]
        self.trial_executor.start_trial(first_trial)
        was_reset = self.trial_executor.reset_trial(first_trial, {"hi": 1},
                                                    "modified_mock")
        self.assertEqual(was_reset, True)
        self.assertEqual(first_trial.config.get("hi"), 1)
        self.assertEqual(first_trial.experiment_tag, "modified_mock")
        self.assertEqual(Trial.RUNNING, first_trial.status)

    @staticmethod
    def generate_trials(spec, name):
        """Return the trials ``BasicVariantGenerator`` yields for ``spec``."""
        variant_generator = BasicVariantGenerator()
        variant_generator.add_configurations({name: spec})
        return variant_generator.next_trials()

    def process_trial_save(self, trial):
        """Stand in for the trial runner: fetch and apply the save result."""
        pending = trial.saving_to
        pending.value = self.trial_executor.fetch_result(trial)
        trial.on_checkpoint(pending)
class RayTrialExecutorTest(unittest.TestCase):
    """Lifecycle tests (start, save, pause, reset) for ``RayTrialExecutor``."""

    def setUp(self):
        """Create the executor under test, then initialise Ray."""
        self.trial_executor = RayTrialExecutor(queue_trials=False)
        ray.init()

    def tearDown(self):
        """Shut Ray down and restore the trainable registry."""
        ray.shutdown()
        _register_all()  # re-register the evicted objects

    def _start_and_expect_running(self, trial):
        """Start ``trial`` and assert that it is RUNNING afterwards."""
        self.trial_executor.start_trial(trial)
        self.assertEqual(Trial.RUNNING, trial.status)

    def _stop_and_expect_terminated(self, trial):
        """Stop ``trial`` and assert that it is TERMINATED afterwards."""
        self.trial_executor.stop_trial(trial)
        self.assertEqual(Trial.TERMINATED, trial.status)

    def testStartStop(self):
        """A started trial shows up as the only running trial."""
        fake_trial = Trial("__fake")
        self.trial_executor.start_trial(fake_trial)
        currently_running = self.trial_executor.get_running_trials()
        self.assertEqual(1, len(currently_running))
        self.trial_executor.stop_trial(fake_trial)

    def testSaveRestore(self):
        """A running trial can be saved to disk, restored and stopped."""
        fake_trial = Trial("__fake")
        self._start_and_expect_running(fake_trial)
        self.trial_executor.save(fake_trial, Checkpoint.DISK)
        self.trial_executor.restore(fake_trial)
        self._stop_and_expect_terminated(fake_trial)

    def testPauseResume(self):
        """Pausing works for a trial whose result was never fetched."""
        fake_trial = Trial("__fake")
        self._start_and_expect_running(fake_trial)
        self.trial_executor.pause_trial(fake_trial)
        self.assertEqual(Trial.PAUSED, fake_trial.status)
        self._start_and_expect_running(fake_trial)
        self._stop_and_expect_terminated(fake_trial)

    def testStartFailure(self):
        """Starting a trial whose registered trainable is None gives ERROR."""
        _global_registry.register(TRAINABLE_CLASS, "asdf", None)
        broken_trial = Trial("asdf", resources=Resources(1, 0))
        self.trial_executor.start_trial(broken_trial)
        self.assertEqual(Trial.ERROR, broken_trial.status)

    def testPauseResume2(self):
        """Pausing works for a trial whose result has been fetched."""
        fake_trial = Trial("__fake")
        self._start_and_expect_running(fake_trial)
        self.trial_executor.fetch_result(fake_trial)
        self.trial_executor.pause_trial(fake_trial)
        self.assertEqual(Trial.PAUSED, fake_trial.status)
        self._start_and_expect_running(fake_trial)
        self._stop_and_expect_terminated(fake_trial)

    def testNoResetTrial(self):
        """``reset_trial`` returns False for "__fake"; the trial keeps running."""
        fake_trial = Trial("__fake")
        self.trial_executor.start_trial(fake_trial)
        was_reset = self.trial_executor.reset_trial(fake_trial, {},
                                                    "modified_mock")
        self.assertEqual(was_reset, False)
        self.assertEqual(Trial.RUNNING, fake_trial.status)

    def testResetTrial(self):
        """``reset_trial`` applies the new config and tag when supported."""

        class B(Trainable):
            def _train(self):
                return dict(timesteps_this_iter=1, done=True)

            def reset_config(self, config):
                self.config = config
                return True

        spec = {"run": B, "config": {"foo": 0}}
        generated = self.generate_trials(spec, "grid_search")
        first_trial = generated[0]
        self.trial_executor.start_trial(first_trial)
        was_reset = self.trial_executor.reset_trial(first_trial, {"hi": 1},
                                                    "modified_mock")
        self.assertEqual(was_reset, True)
        self.assertEqual(first_trial.config.get("hi"), 1)
        self.assertEqual(first_trial.experiment_tag, "modified_mock")
        self.assertEqual(Trial.RUNNING, first_trial.status)

    def generate_trials(self, spec, name):
        """Return the trials ``BasicVariantGenerator`` yields for ``spec``."""
        variant_generator = BasicVariantGenerator()
        variant_generator.add_configurations({name: spec})
        return variant_generator.next_trials()