def testCheckpointAutoPeriod(self):
    ray.init(num_cpus=3)

    # This makes checkpointing take 2 seconds.
    class CustomSyncer(Syncer):
        def __init__(self, sync_period: float = 300.0):
            super(CustomSyncer, self).__init__(sync_period=sync_period)
            self._sync_status = {}

        def sync_up(
            self, local_dir: str, remote_dir: str, exclude: list = None
        ) -> bool:
            time.sleep(2)
            return True

        def sync_down(
            self, remote_dir: str, local_dir: str, exclude: list = None
        ) -> bool:
            time.sleep(2)
            return True

        def delete(self, remote_dir: str) -> bool:
            pass

    runner = TrialRunner(
        local_checkpoint_dir=self.tmpdir,
        checkpoint_period="auto",
        sync_config=SyncConfig(
            upload_dir="fake", syncer=CustomSyncer(), sync_period=0
        ),
        remote_checkpoint_dir="fake",
    )
    runner.add_trial(Trial("__fake", config={"user_checkpoint_freq": 1}))

    runner.step()  # Run one step, this will trigger checkpointing

    self.assertGreaterEqual(
        runner._checkpoint_manager._checkpoint_period, 38.0
    )

def testRestoreMetricsAfterCheckpointing(self):
    ray.init(num_cpus=1, num_gpus=1)

    observer = TrialResultObserver()
    runner = TrialRunner(callbacks=[observer])
    kwargs = {
        "stopping_criterion": {"training_iteration": 2},
        "resources": Resources(cpu=1, gpu=1),
        "checkpoint_freq": 1,
    }
    runner.add_trial(Trial("__fake", **kwargs))
    trials = runner.get_trials()

    while not runner.is_finished():
        runner.step()

    self.assertEqual(trials[0].status, Trial.TERMINATED)

    kwargs["restore_path"] = trials[0].checkpoint.dir_or_data
    kwargs.pop("stopping_criterion")
    kwargs.pop("checkpoint_freq")  # No checkpointing for next trial
    runner.add_trial(Trial("__fake", **kwargs))
    trials = runner.get_trials()

    observer.reset()
    while not observer.just_received_a_result():
        runner.step()
    self.assertEqual(trials[1].last_result["timesteps_since_restore"], 10)
    self.assertEqual(trials[1].last_result["iterations_since_restore"], 1)
    self.assertGreater(trials[1].last_result["time_since_restore"], 0)

    while not observer.just_received_a_result():
        runner.step()
    self.assertEqual(trials[1].last_result["timesteps_since_restore"], 20)
    self.assertEqual(trials[1].last_result["iterations_since_restore"], 2)
    self.assertGreater(trials[1].last_result["time_since_restore"], 0)

    self.addCleanup(os.remove, trials[0].checkpoint.dir_or_data)

def testFailureRecoveryDisabled(self):
    ray.init(num_cpus=1, num_gpus=1)
    searchalg, scheduler = create_mock_components()

    runner = TrialRunner(searchalg, scheduler=scheduler)
    kwargs = {
        "resources": Resources(cpu=1, gpu=1),
        "checkpoint_freq": 1,
        "max_failures": 0,
        "config": {
            "mock_error": True,
        },
    }
    runner.add_trial(Trial("__fake", **kwargs))
    trials = runner.get_trials()

    while not runner.is_finished():
        runner.step()

    self.assertEqual(trials[0].status, Trial.ERROR)
    self.assertEqual(trials[0].num_failures, 1)
    self.assertEqual(len(searchalg.errored_trials), 1)
    self.assertEqual(len(scheduler.errored_trials), 1)

def testFailFast(self):
    ray.init(num_cpus=1, num_gpus=1)
    runner = TrialRunner(fail_fast=True)
    kwargs = {
        "resources": Resources(cpu=1, gpu=1),
        "checkpoint_freq": 1,
        "max_failures": 0,
        "config": {
            "mock_error": True,
            "persistent_error": True,
        },
    }
    runner.add_trial(Trial("__fake", **kwargs))
    runner.add_trial(Trial("__fake", **kwargs))
    trials = runner.get_trials()

    while not runner.is_finished():
        runner.step()
    self.assertEqual(trials[0].status, Trial.ERROR)
    # With `fail_fast=True`, once one trial errors out, the remaining
    # trials are stopped with `TERMINATED` status.
    self.assertEqual(trials[1].status, Trial.TERMINATED)
    self.assertRaises(TuneError, lambda: runner.step())

def create_tune_experiment_checkpoint(trials: list, **runner_kwargs) -> str:
    experiment_dir = tempfile.mkdtemp()
    runner_kwargs.setdefault("local_checkpoint_dir", experiment_dir)

    # Update environment
    orig_env = os.environ.copy()

    # Set to 1 to disable ray cluster resource lookup. That way we can
    # create experiment checkpoints without initializing ray.
    os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"

    try:
        runner = TrialRunner(**runner_kwargs)

        for trial in trials:
            runner.add_trial(trial)

        runner.checkpoint(force=True)
    finally:
        os.environ.clear()
        os.environ.update(orig_env)

    return experiment_dir

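# Usage sketch for the helper above (illustrative only, not part of the
# original suite): build an experiment checkpoint for a few trials without
# a running ray cluster, then resume from it with a new TrialRunner once
# ray is initialized.
#
#     experiment_dir = create_tune_experiment_checkpoint(
#         [Trial("__fake"), Trial("__fake")]
#     )
#     runner = TrialRunner(resume="LOCAL", local_checkpoint_dir=experiment_dir)
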
def testFailFastRaise(self):
    ray.init(num_cpus=1, num_gpus=1)
    runner = TrialRunner(fail_fast=TrialRunner.RAISE)
    kwargs = {
        "resources": Resources(cpu=1, gpu=1),
        "checkpoint_freq": 1,
        "max_failures": 0,
        "config": {
            "mock_error": True,
            "persistent_error": True,
        },
    }
    runner.add_trial(Trial("__fake", **kwargs))
    runner.add_trial(Trial("__fake", **kwargs))
    trials = runner.get_trials()

    with self.assertRaises(Exception):
        while not runner.is_finished():
            runner.step()

    # These checks are not critical; they only showcase the difference
    # from the non-raising `fail_fast=True` mode.
    self.assertEqual(trials[0].status, Trial.RUNNING)
    self.assertEqual(trials[1].status, Trial.PENDING)

def testUserCheckpoint(self):
    os.environ["TUNE_RESULT_BUFFER_LENGTH"] = "1"  # Don't finish early
    os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"

    ray.init(num_cpus=3)
    runner = TrialRunner(
        local_checkpoint_dir=self.tmpdir, checkpoint_period=0
    )
    runner.add_trial(Trial("__fake", config={"user_checkpoint_freq": 2}))
    trials = runner.get_trials()

    runner.step()  # Start trial
    self.assertEqual(trials[0].status, Trial.RUNNING)
    self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
    runner.step()  # Process result
    self.assertFalse(trials[0].has_checkpoint())
    runner.step()  # Process result
    self.assertFalse(trials[0].has_checkpoint())
    runner.step()  # Process result, dispatch save
    runner.step()  # Process save
    self.assertTrue(trials[0].has_checkpoint())

    runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=self.tmpdir)
    runner2.step()  # Start trial and dispatch restore
    trials2 = runner2.get_trials()
    self.assertEqual(ray.get(trials2[0].runner.get_info.remote()), 1)

def testTrialNoCheckpointSave(self):
    """Check that non-checkpointing trials *are* saved."""
    os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "1"

    ray.init(num_cpus=3)

    runner = TrialRunner(
        local_checkpoint_dir=self.tmpdir, checkpoint_period=0
    )
    runner.add_trial(
        Trial(
            "__fake",
            trial_id="non_checkpoint",
            stopping_criterion={"training_iteration": 2},
        )
    )

    while not all(t.status == Trial.TERMINATED for t in runner.get_trials()):
        runner.step()

    runner.add_trial(
        Trial(
            "__fake",
            trial_id="checkpoint",
            checkpoint_at_end=True,
            stopping_criterion={"training_iteration": 2},
        )
    )

    while not all(t.status == Trial.TERMINATED for t in runner.get_trials()):
        runner.step()

    runner.add_trial(
        Trial(
            "__fake",
            trial_id="pending",
            stopping_criterion={"training_iteration": 2},
        )
    )

    old_trials = runner.get_trials()
    while not old_trials[2].has_reported_at_least_once:
        runner.step()

    runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=self.tmpdir)
    new_trials = runner2.get_trials()
    self.assertEqual(len(new_trials), 3)
    self.assertTrue(
        runner2.get_trial("non_checkpoint").status == Trial.TERMINATED
    )
    self.assertTrue(
        runner2.get_trial("checkpoint").status == Trial.TERMINATED
    )
    self.assertTrue(runner2.get_trial("pending").status == Trial.PENDING)
    self.assertTrue(runner2.get_trial("pending").has_reported_at_least_once)
    runner2.step()

def testPauseResumeCheckpointCount(self):
    ray.init(num_cpus=2)
    tempdir = tempfile.mkdtemp()
    self.addCleanup(shutil.rmtree, tempdir)

    trial = Trial("__fake", keep_checkpoints_num=2)
    trial.init_logdir()
    trial.checkpoint_manager.set_delete_fn(
        lambda cp: shutil.rmtree(cp.dir_or_data)
    )

    def write_checkpoint(trial: Trial, index: int):
        checkpoint_dir = TrainableUtil.make_checkpoint_dir(
            trial.logdir, index=index
        )
        result = {"training_iteration": index}
        with open(os.path.join(checkpoint_dir, "cp.json"), "w") as f:
            json.dump(result, f)

        tune_cp = _TrackedCheckpoint(
            dir_or_data=checkpoint_dir,
            storage_mode=CheckpointStorage.PERSISTENT,
            metrics=result,
        )
        trial.saving_to = tune_cp

        return checkpoint_dir

    def get_checkpoint_dirs(trial: Trial):
        return [
            d for d in os.listdir(trial.logdir) if d.startswith("checkpoint_")
        ]

    runner = TrialRunner(local_checkpoint_dir=tempdir)
    runner.add_trial(trial)

    # Write 1 checkpoint
    result = write_checkpoint(trial, 1)
    runner._on_saving_result(trial, result)

    # Expect 1 checkpoint
    cp_dirs = get_checkpoint_dirs(trial)
    self.assertEqual(len(cp_dirs), 1, msg=f"Checkpoint dirs: {cp_dirs}")

    # Write second checkpoint
    result = write_checkpoint(trial, 2)
    runner._on_saving_result(trial, result)

    # Expect 2 checkpoints
    cp_dirs = get_checkpoint_dirs(trial)
    self.assertEqual(len(cp_dirs), 2, msg=f"Checkpoint dirs: {cp_dirs}")

    # Write third checkpoint
    result = write_checkpoint(trial, 3)
    runner._on_saving_result(trial, result)

    # Expect 2 checkpoints because keep_checkpoints_num = 2
    cp_dirs = get_checkpoint_dirs(trial)
    self.assertEqual(len(cp_dirs), 2, msg=f"Checkpoint dirs: {cp_dirs}")

    # Re-instantiate trial runner and resume
    runner.checkpoint(force=True)
    runner = TrialRunner(local_checkpoint_dir=tempdir)
    runner.resume()

    trial = runner.get_trials()[0]
    trial.checkpoint_manager.set_delete_fn(
        lambda cp: shutil.rmtree(cp.dir_or_data)
    )

    # Write fourth checkpoint
    result = write_checkpoint(trial, 4)
    runner._on_saving_result(trial, result)

    # Expect 2 checkpoints because keep_checkpoints_num = 2
    cp_dirs = get_checkpoint_dirs(trial)
    self.assertEqual(len(cp_dirs), 2, msg=f"Checkpoint dirs: {cp_dirs}")

    # Write fifth checkpoint
    result = write_checkpoint(trial, 5)
    runner._on_saving_result(trial, result)

    # Expect 2 checkpoints because keep_checkpoints_num = 2
    cp_dirs = get_checkpoint_dirs(trial)
    self.assertEqual(len(cp_dirs), 2, msg=f"Checkpoint dirs: {cp_dirs}")

    # Checkpoints from before the restore should have been deleted
    self.assertIn("checkpoint_000004", cp_dirs)
    self.assertIn("checkpoint_000005", cp_dirs)

    self.assertNotIn("checkpoint_000002", cp_dirs)
    self.assertNotIn("checkpoint_000003", cp_dirs)

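# For reference, the retention behavior exercised above is what users get
# from the public API through `keep_checkpoints_num` (a hedged sketch; the
# other parameters are illustrative and not taken from this suite):
#
#     tune.run(
#         "__fake",
#         stop={"training_iteration": 5},
#         checkpoint_freq=1,
#         keep_checkpoints_num=2,  # keep only the 2 most recent checkpoints
#     )
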
def testTrialSaveRestore(self):
    """Creates different trials to test runner.checkpoint/restore."""
    ray.init(num_cpus=3)

    runner = TrialRunner(
        local_checkpoint_dir=self.tmpdir, checkpoint_period=0
    )
    trials = [
        Trial(
            "__fake",
            trial_id="trial_terminate",
            stopping_criterion={"training_iteration": 1},
            checkpoint_freq=1,
        )
    ]
    runner.add_trial(trials[0])
    while not runner.is_finished():
        # Start trial, process result, dispatch save and process save.
        runner.step()
    self.assertEqual(trials[0].status, Trial.TERMINATED)

    trials += [
        Trial(
            "__fake",
            trial_id="trial_fail",
            stopping_criterion={"training_iteration": 3},
            checkpoint_freq=1,
            config={"mock_error": True},
        )
    ]
    runner.add_trial(trials[1])
    while not runner.is_finished():
        # Start trial, process result, dispatch save,
        # process save and error.
        runner.step()
    self.assertEqual(trials[1].status, Trial.ERROR)

    trials += [
        Trial(
            "__fake",
            trial_id="trial_succ",
            stopping_criterion={"training_iteration": 2},
            checkpoint_freq=1,
        )
    ]
    runner.add_trial(trials[2])
    runner.step()  # Start trial
    self.assertEqual(len(runner.trial_executor.get_checkpoints()), 3)
    self.assertEqual(trials[2].status, Trial.RUNNING)

    runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=self.tmpdir)
    for tid in ["trial_terminate", "trial_fail"]:
        original_trial = runner.get_trial(tid)
        restored_trial = runner2.get_trial(tid)
        self.assertEqual(original_trial.status, restored_trial.status)

    restored_trial = runner2.get_trial("trial_succ")
    self.assertEqual(Trial.PENDING, restored_trial.status)

    while not runner2.is_finished():
        # Start trial, process result, dispatch save, process save,
        # process result, dispatch save, process save.
        runner2.step()
    self.assertEqual(restored_trial.status, Trial.TERMINATED)

class TrialRunnerCallbacks(unittest.TestCase):
    def setUp(self):
        ray.init()
        self.tmpdir = tempfile.mkdtemp()
        self.callback = TestCallback()
        self.executor = _MockTrialExecutor()
        self.trial_runner = TrialRunner(
            trial_executor=self.executor, callbacks=[self.callback]
        )
        # experiment would never be None normally, but it's fine for testing
        self.trial_runner.setup_experiments(
            experiments=[None], total_num_samples=1
        )

    def tearDown(self):
        ray.shutdown()
        _register_all()  # re-register the evicted objects
        if "CUDA_VISIBLE_DEVICES" in os.environ:
            del os.environ["CUDA_VISIBLE_DEVICES"]
        shutil.rmtree(self.tmpdir)

    def testCallbackSteps(self):
        trials = [
            Trial("__fake", trial_id="one"),
            Trial("__fake", trial_id="two"),
        ]
        for t in trials:
            self.trial_runner.add_trial(t)

        self.executor.next_future_result = _ExecutorEvent(
            event_type=_ExecutorEventType.PG_READY
        )
        self.trial_runner.step()

        # Trial 1 has been started
        self.assertEqual(self.callback.state["trial_start"]["iteration"], 0)
        self.assertEqual(
            self.callback.state["trial_start"]["trial"].trial_id, "one"
        )

        # All these events haven't happened, yet
        self.assertTrue(
            all(
                k not in self.callback.state
                for k in [
                    "trial_restore",
                    "trial_save",
                    "trial_result",
                    "trial_complete",
                    "trial_fail",
                    "experiment_end",
                ]
            )
        )

        self.executor.next_future_result = _ExecutorEvent(
            event_type=_ExecutorEventType.PG_READY
        )
        self.trial_runner.step()

        # Iteration not increased yet
        self.assertEqual(self.callback.state["step_begin"]["iteration"], 1)
        # Iteration increased
        self.assertEqual(self.callback.state["step_end"]["iteration"], 2)
        # Second trial has just been started
        self.assertEqual(self.callback.state["trial_start"]["iteration"], 1)
        self.assertEqual(
            self.callback.state["trial_start"]["trial"].trial_id, "two"
        )

        # Just a placeholder object ref for cp.dir_or_data.
        cp = _TrackedCheckpoint(
            dir_or_data=ray.put(1),
            storage_mode=CheckpointStorage.PERSISTENT,
            metrics={TRAINING_ITERATION: 0},
        )
        trials[0].saving_to = cp

        # Let the first trial save a checkpoint
        self.executor.next_future_result = _ExecutorEvent(
            event_type=_ExecutorEventType.SAVING_RESULT,
            trial=trials[0],
            result={_ExecutorEvent.KEY_FUTURE_RESULT: "__checkpoint"},
        )
        self.trial_runner.step()
        self.assertEqual(self.callback.state["trial_save"]["iteration"], 2)
        self.assertEqual(
            self.callback.state["trial_save"]["trial"].trial_id, "one"
        )

        # Let the second trial send a result
        result = {TRAINING_ITERATION: 1, "metric": 800, "done": False}
        self.executor.next_future_result = _ExecutorEvent(
            event_type=_ExecutorEventType.TRAINING_RESULT,
            trial=trials[1],
            result={_ExecutorEvent.KEY_FUTURE_RESULT: result},
        )
        self.assertTrue(not trials[1].has_reported_at_least_once)
        self.trial_runner.step()
        self.assertEqual(self.callback.state["trial_result"]["iteration"], 3)
        self.assertEqual(
            self.callback.state["trial_result"]["trial"].trial_id, "two"
        )
        self.assertEqual(
            self.callback.state["trial_result"]["result"]["metric"], 800
        )
        self.assertEqual(trials[1].last_result["metric"], 800)

        # Let the second trial restore from a checkpoint
        trials[1].restoring_from = cp
        self.executor.next_future_result = _ExecutorEvent(
            event_type=_ExecutorEventType.RESTORING_RESULT, trial=trials[1]
        )
        self.trial_runner.step()
        self.assertEqual(self.callback.state["trial_restore"]["iteration"], 4)
        self.assertEqual(
            self.callback.state["trial_restore"]["trial"].trial_id, "two"
        )

        # Let the second trial finish
        trials[1].restoring_from = None
        self.executor.next_future_result = _ExecutorEvent(
            event_type=_ExecutorEventType.TRAINING_RESULT,
            trial=trials[1],
            result={
                _ExecutorEvent.KEY_FUTURE_RESULT: {
                    TRAINING_ITERATION: 2,
                    "metric": 900,
                    "done": True,
                }
            },
        )
        self.trial_runner.step()
        self.assertEqual(self.callback.state["trial_complete"]["iteration"], 5)
        self.assertEqual(
            self.callback.state["trial_complete"]["trial"].trial_id, "two"
        )

        # Let the first trial error
        self.executor.next_future_result = _ExecutorEvent(
            event_type=_ExecutorEventType.ERROR,
            trial=trials[0],
            result={_ExecutorEvent.KEY_EXCEPTION: Exception()},
        )
        self.trial_runner.step()
        self.assertEqual(self.callback.state["trial_fail"]["iteration"], 6)
        self.assertEqual(
            self.callback.state["trial_fail"]["trial"].trial_id, "one"
        )

    def testCallbacksEndToEnd(self):
        def train(config):
            if config["do"] == "save":
                with tune.checkpoint_dir(0):
                    pass
                tune.report(metric=1)
            elif config["do"] == "fail":
                raise RuntimeError("I am failing on purpose.")
            elif config["do"] == "delay":
                time.sleep(2)
                tune.report(metric=20)

        config = {"do": tune.grid_search(["save", "fail", "delay"])}

        tune.run(
            train,
            config=config,
            raise_on_failed_trial=False,
            callbacks=[self.callback],
        )

        self.assertIn("setup", self.callback.state)
        self.assertTrue(self.callback.state["setup"] is not None)
        keys = Experiment.PUBLIC_KEYS.copy()
        keys.add("total_num_samples")
        for key in keys:
            self.assertIn(key, self.callback.state["setup"])
        # check if it was added first
        self.assertTrue(list(self.callback.state)[0] == "setup")
        self.assertEqual(
            self.callback.state["trial_fail"]["trial"].config["do"], "fail"
        )
        self.assertEqual(
            self.callback.state["trial_save"]["trial"].config["do"], "save"
        )
        self.assertEqual(
            self.callback.state["trial_result"]["trial"].config["do"], "delay"
        )
        self.assertEqual(
            self.callback.state["trial_complete"]["trial"].config["do"],
            "delay",
        )
        self.assertIn("experiment_end", self.callback.state)
        # check if it was added last
        self.assertTrue(list(self.callback.state)[-1] == "experiment_end")

    def testCallbackReordering(self):
        """SyncerCallback should come after LoggerCallback callbacks"""

        def get_positions(callbacks):
            first_logger_pos = None
            last_logger_pos = None
            syncer_pos = None
            for i, callback in enumerate(callbacks):
                if isinstance(callback, LoggerCallback):
                    if first_logger_pos is None:
                        first_logger_pos = i
                    last_logger_pos = i
                elif isinstance(callback, SyncerCallback):
                    syncer_pos = i
            return first_logger_pos, last_logger_pos, syncer_pos

        # Auto creation of loggers, no callbacks, no syncer
        callbacks = create_default_callbacks(None, SyncConfig(), None)
        first_logger_pos, last_logger_pos, syncer_pos = get_positions(callbacks)
        self.assertLess(last_logger_pos, syncer_pos)

        # Auto creation of loggers with callbacks
        callbacks = create_default_callbacks([Callback()], SyncConfig(), None)
        first_logger_pos, last_logger_pos, syncer_pos = get_positions(callbacks)
        self.assertLess(last_logger_pos, syncer_pos)

        # Auto creation of loggers with existing logger (but no CSV/JSON)
        callbacks = create_default_callbacks(
            [LoggerCallback()], SyncConfig(), None
        )
        first_logger_pos, last_logger_pos, syncer_pos = get_positions(callbacks)
        self.assertLess(last_logger_pos, syncer_pos)

        # This should be reordered but preserve the regular callback order
        [mc1, mc2, mc3] = [Callback(), Callback(), Callback()]
        # Has to be a legacy logger to avoid logger callback creation
        lc = LegacyLoggerCallback(logger_classes=DEFAULT_LOGGERS)
        callbacks = create_default_callbacks(
            [mc1, mc2, lc, mc3], SyncConfig(), None
        )
        first_logger_pos, last_logger_pos, syncer_pos = get_positions(callbacks)
        self.assertLess(last_logger_pos, syncer_pos)
        self.assertLess(callbacks.index(mc1), callbacks.index(mc2))
        self.assertLess(callbacks.index(mc2), callbacks.index(mc3))
        self.assertLess(callbacks.index(lc), callbacks.index(mc3))
        # Syncer callback is appended
        self.assertLess(callbacks.index(mc3), syncer_pos)

    @patch.object(warnings, "warn")
    def testCallbackSetupBackwardsCompatible(self, mocked_warning_method):
        class NoExperimentInSetupCallback(Callback):
            # Old method definition didn't take in **experiment.public_spec
            def setup(self):
                return

        callback = NoExperimentInSetupCallback()
        trial_runner = TrialRunner(callbacks=[callback])
        trial_runner.setup_experiments(
            experiments=[Experiment("", lambda x: x)], total_num_samples=1
        )
        mocked_warning_method.assert_called_once()
        self.assertIn(
            "Please update", mocked_warning_method.call_args_list[0][0][0]
        )

def test_trial_migration(start_connected_emptyhead_cluster, tmpdir, durable):
    """Removing a node while the cluster has space should migrate the trial.

    The trial state should also be consistent with the checkpoint.
    """
    cluster = start_connected_emptyhead_cluster
    node = cluster.add_node(num_cpus=1)
    cluster.wait_for_nodes()

    if durable:
        upload_dir = "file://" + str(tmpdir)
        syncer_callback = SyncerCallback()
    else:
        upload_dir = None
        syncer_callback = custom_driver_logdir_callback(str(tmpdir))

    runner = TrialRunner(BasicVariantGenerator(), callbacks=[syncer_callback])
    kwargs = {
        "stopping_criterion": {"training_iteration": 4},
        "checkpoint_freq": 2,
        "max_failures": 2,
        "remote_checkpoint_dir": upload_dir,
    }

    # Test recovery of trial that hasn't been checkpointed
    t = Trial("__fake", **kwargs)
    runner.add_trial(t)
    runner.step()  # Start trial
    runner.step()  # Process result
    assert t.last_result
    node2 = cluster.add_node(num_cpus=1)
    cluster.remove_node(node)
    cluster.wait_for_nodes()
    # TODO(ujvl): Node failure does not propagate until a step after it
    #  actually should. This is possibly a problem with `Cluster`.
    runner.step()
    runner.step()  # Recovery step

    # TODO(rliaw): This assertion is not critical but will not pass
    #  because checkpoint handling is messy and should be refactored
    #  rather than hotfixed.
    # assert t.last_result is None, "Trial result not restored correctly."

    # Process result (x2), process save, process result (x2), process save
    while not runner.is_finished():
        runner.step()

    assert t.status == Trial.TERMINATED, runner.debug_string()

    # Test recovery of trial that has been checkpointed
    t2 = Trial("__fake", **kwargs)
    runner.add_trial(t2)
    # Start trial, process result (x2), process save
    while not t2.has_checkpoint():
        runner.step()
    node3 = cluster.add_node(num_cpus=1)
    cluster.remove_node(node2)
    cluster.wait_for_nodes()
    while not runner.is_finished():
        runner.step()
    assert t2.status == Trial.TERMINATED, runner.debug_string()

    # Test recovery of trial that won't be checkpointed
    kwargs = {
        "stopping_criterion": {"training_iteration": 3},
        "remote_checkpoint_dir": upload_dir,
    }
    t3 = Trial("__fake", **kwargs)
    runner.add_trial(t3)
    runner.step()  # Start trial
    runner.step()  # Process result 1
    cluster.add_node(num_cpus=1)
    cluster.remove_node(node3)
    cluster.wait_for_nodes()
    while not runner.is_finished():
        runner.step()
    assert t3.status == Trial.ERROR, runner.debug_string()

    with pytest.raises(TuneError):
        runner.step()

def testBurnInPeriod(self):
    runner = TrialRunner(trial_executor=MagicMock())

    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="error",
        mode="min",
        perturbation_interval=5,
        burn_in_period=50,
        log_config=True,
        synch=True,
    )

    class MockTrial(Trial):
        @property
        def checkpoint(self):
            return _TrackedCheckpoint(
                dir_or_data={"data": "None"},
                storage_mode=CheckpointStorage.MEMORY,
                metrics={},
            )

        @property
        def status(self):
            return Trial.PAUSED

        @status.setter
        def status(self, status):
            pass

    trial1 = MockTrial("PPO", config=dict(num=1))
    trial2 = MockTrial("PPO", config=dict(num=2))
    trial3 = MockTrial("PPO", config=dict(num=3))
    trial4 = MockTrial("PPO", config=dict(num=4))

    runner.add_trial(trial1)
    runner.add_trial(trial2)
    runner.add_trial(trial3)
    runner.add_trial(trial4)

    scheduler.on_trial_add(runner, trial1)
    scheduler.on_trial_add(runner, trial2)
    scheduler.on_trial_add(runner, trial3)
    scheduler.on_trial_add(runner, trial4)

    # Add initial results.
    scheduler.on_trial_result(
        runner, trial1, result=dict(training_iteration=1, error=50)
    )
    scheduler.on_trial_result(
        runner, trial2, result=dict(training_iteration=1, error=50)
    )
    scheduler.on_trial_result(
        runner, trial3, result=dict(training_iteration=1, error=10)
    )
    scheduler.on_trial_result(
        runner, trial4, result=dict(training_iteration=1, error=100)
    )

    # Add more results. Without burn-in, this would now exploit.
    scheduler.on_trial_result(
        runner, trial1, result=dict(training_iteration=30, error=50)
    )
    scheduler.on_trial_result(
        runner, trial2, result=dict(training_iteration=30, error=50)
    )
    scheduler.on_trial_result(
        runner, trial3, result=dict(training_iteration=30, error=10)
    )
    scheduler.on_trial_result(
        runner, trial4, result=dict(training_iteration=30, error=100)
    )

    self.assertEqual(trial4.config["num"], 4)

    # Add more results. Since this is after burn-in, it should now exploit.
    scheduler.on_trial_result(
        runner, trial1, result=dict(training_iteration=50, error=50)
    )
    scheduler.on_trial_result(
        runner, trial2, result=dict(training_iteration=50, error=50)
    )
    scheduler.on_trial_result(
        runner, trial3, result=dict(training_iteration=50, error=10)
    )
    scheduler.on_trial_result(
        runner, trial4, result=dict(training_iteration=50, error=100)
    )

    self.assertEqual(trial4.config["num"], 3)

    # Assert that trials do not hang after `burn_in_period`
    self.assertTrue(all(t.status == "PAUSED" for t in runner.get_trials()))
    self.assertTrue(scheduler.choose_trial_to_run(runner))

    # Assert that trials do not hang when a terminated trial is added
    trial5 = Trial("PPO", config=dict(num=5))
    runner.add_trial(trial5)
    scheduler.on_trial_add(runner, trial5)
    trial5.set_status(Trial.TERMINATED)
    self.assertTrue(scheduler.choose_trial_to_run(runner))