def testSavePauseResumeErrorRestore(self):
    """Tests that pause checkpoint does not replace restore checkpoint."""
    t = Trial("__fake")
    self._simulate_starting_trial(t)
    self._simulate_getting_result(t)

    # Persist a checkpoint for the trial.
    self._simulate_saving(t)

    # Train past the saved checkpoint.
    self.trial_executor.continue_training(t)
    self._simulate_getting_result(t)

    # Pausing should leave the trial PAUSED with an in-memory checkpoint.
    self.trial_executor.pause_trial(t)
    self.assertEqual(Trial.PAUSED, t.status)
    self.assertEqual(t.checkpoint.storage_mode, CheckpointStorage.MEMORY)

    # Resume the paused trial.
    self._simulate_starting_trial(t)

    # Simulate a failure after resuming.
    t.set_status(Trial.ERROR)

    # Restoring and stopping must still terminate the trial cleanly —
    # i.e. the pause checkpoint did not clobber the restore checkpoint.
    self.trial_executor.restore(t)
    self.trial_executor.stop_trial(t)
    self.assertEqual(Trial.TERMINATED, t.status)
def testBurnInPeriod(self):
    """PBT with a burn-in period must not exploit until burn-in has elapsed."""
    runner = TrialRunner(trial_executor=MagicMock())

    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="error",
        mode="min",
        perturbation_interval=5,
        burn_in_period=50,
        log_config=True,
        synch=True,
    )

    class MockTrial(Trial):
        # Always reports a usable in-memory checkpoint and a PAUSED status,
        # so the scheduler is free to exploit whenever it decides to.
        @property
        def checkpoint(self):
            return _TrackedCheckpoint(
                dir_or_data={"data": "None"},
                storage_mode=CheckpointStorage.MEMORY,
                metrics={},
            )

        @property
        def status(self):
            return Trial.PAUSED

        @status.setter
        def status(self, status):
            pass

    trials = [MockTrial("PPO", config=dict(num=i)) for i in range(1, 5)]
    for t in trials:
        runner.add_trial(t)
    for t in trials:
        scheduler.on_trial_add(runner, t)

    # Per-trial errors: trial 3 is the best, trial 4 the worst.
    errors = [50, 50, 10, 100]

    # Add initial results.
    for t, err in zip(trials, errors):
        scheduler.on_trial_result(
            runner, t, result=dict(training_iteration=1, error=err)
        )

    # Add more results. Without burn-in, this would now exploit.
    for t, err in zip(trials, errors):
        scheduler.on_trial_result(
            runner, t, result=dict(training_iteration=30, error=err)
        )
    self.assertEqual(trials[3].config["num"], 4)

    # Add more results. Since this is after burn-in, it should now exploit:
    # the worst trial (4) copies the config of the best trial (3).
    for t, err in zip(trials, errors):
        scheduler.on_trial_result(
            runner, t, result=dict(training_iteration=50, error=err)
        )
    self.assertEqual(trials[3].config["num"], 3)

    # Assert that trials do not hang after `burn_in_period`.
    self.assertTrue(all(t.status == "PAUSED" for t in runner.get_trials()))
    self.assertTrue(scheduler.choose_trial_to_run(runner))

    # Assert that trials do not hang when a terminated trial is added.
    trial5 = Trial("PPO", config=dict(num=5))
    runner.add_trial(trial5)
    scheduler.on_trial_add(runner, trial5)
    trial5.set_status(Trial.TERMINATED)
    self.assertTrue(scheduler.choose_trial_to_run(runner))