def testSavePauseResumeErrorRestore(self):
    """Tests that pause checkpoint does not replace restore checkpoint."""
    trial = Trial("__fake")
    self._simulate_starting_trial(trial)

    self._simulate_getting_result(trial)

    # Save
    self._simulate_saving(trial)

    # Train
    self.trial_executor.continue_training(trial)
    self._simulate_getting_result(trial)

    # Pause
    self.trial_executor.pause_trial(trial)
    self.assertEqual(Trial.PAUSED, trial.status)
    self.assertEqual(trial.checkpoint.storage, Checkpoint.MEMORY)

    # Resume
    self._simulate_starting_trial(trial)

    # Error
    trial.set_status(Trial.ERROR)

    # Restore
    self.trial_executor.restore(trial)

    self.trial_executor.stop_trial(trial)
    self.assertEqual(Trial.TERMINATED, trial.status)
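# A minimal sketch of the `_simulate_*` helpers the test above assumes.
# Their real bodies are not shown in this snippet; the versions below are
# inferred from the direct-call variant of the same test (next snippet)
# and are for illustration only.
def _simulate_starting_trial(self, trial):
    # Start the trial and check it reaches RUNNING.
    self.trial_executor.start_trial(trial)
    self.assertEqual(Trial.RUNNING, trial.status)

def _simulate_getting_result(self, trial):
    # Fetch the latest training result and store it on the trial.
    trial.last_result = self.trial_executor.fetch_result(trial)[-1]

def _simulate_saving(self, trial):
    # Request a persistent checkpoint and process the save result the way
    # the trial runner would.
    checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
    self.assertEqual(checkpoint.storage, Checkpoint.PERSISTENT)
    self.process_trial_save(trial)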
def testSavePauseResumeErrorRestore(self):
    """Tests that pause checkpoint does not replace restore checkpoint."""
    trial = Trial("__fake")
    self.trial_executor.start_trial(trial)
    trial.last_result = self.trial_executor.fetch_result(trial)[-1]

    # Save
    checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
    self.assertEqual(Trial.RUNNING, trial.status)
    self.assertEqual(checkpoint.storage, Checkpoint.PERSISTENT)

    # Process save result (simulates trial runner)
    self.process_trial_save(trial)

    # Train
    self.trial_executor.continue_training(trial)
    trial.last_result = self.trial_executor.fetch_result(trial)[-1]

    # Pause
    self.trial_executor.pause_trial(trial)
    self.assertEqual(Trial.PAUSED, trial.status)
    self.assertEqual(trial.checkpoint.storage, Checkpoint.MEMORY)

    # Resume
    self.trial_executor.start_trial(trial)
    self.assertEqual(Trial.RUNNING, trial.status)

    # Error
    trial.set_status(Trial.ERROR)

    # Restore
    self.trial_executor.restore(trial)

    self.trial_executor.stop_trial(trial)
    self.assertEqual(Trial.TERMINATED, trial.status)
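# A minimal sketch of the `process_trial_save` helper referenced above. It
# stands in for the trial runner resolving an in-flight save; the body below
# is an assumption for illustration, not the exact helper from the test file.
def process_trial_save(self, trial):
    # Resolve the pending save and attach the finished checkpoint to the
    # trial, as the trial runner would after the save result comes back.
    checkpoint = trial.saving_to
    checkpoint.value = self.trial_executor.fetch_result(trial)[-1]
    trial.on_checkpoint(checkpoint)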
def set_status(self, trial: Trial, status: str) -> None:
    """Sets status and checkpoints metadata if needed.

    Only checkpoints metadata if trial status is a terminal condition.
    PENDING, PAUSED, and RUNNING switches have checkpoints taken care of
    in the TrialRunner.

    Args:
        trial (Trial): Trial to checkpoint.
        status (Trial.status): Status to set trial to.
    """
    if trial.status == status:
        logger.debug("Trial %s: Status %s unchanged.", trial, trial.status)
    else:
        logger.debug(
            "Trial %s: Changing status from %s to %s.",
            trial,
            trial.status,
            status,
        )
    trial.set_status(status)
    if status in [Trial.TERMINATED, Trial.ERROR]:
        self._trials_to_cache.add(trial)
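# Usage sketch (hypothetical `executor` and `trial` objects): only terminal
# transitions queue the trial for metadata caching, per the docstring above.
executor.set_status(trial, Trial.PAUSED)      # logged; nothing cached
executor.set_status(trial, Trial.TERMINATED)  # logged; trial added to
                                              # `_trials_to_cache`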
def testBurnInPeriod(self):
    runner = TrialRunner(trial_executor=MagicMock())

    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="error",
        mode="min",
        perturbation_interval=5,
        burn_in_period=50,
        log_config=True,
        synch=True,
    )

    class MockTrial(Trial):
        @property
        def checkpoint(self):
            return _TrackedCheckpoint(
                dir_or_data={"data": "None"},
                storage_mode=CheckpointStorage.MEMORY,
                metrics={},
            )

        @property
        def status(self):
            return Trial.PAUSED

        @status.setter
        def status(self, status):
            pass

    trial1 = MockTrial("PPO", config=dict(num=1))
    trial2 = MockTrial("PPO", config=dict(num=2))
    trial3 = MockTrial("PPO", config=dict(num=3))
    trial4 = MockTrial("PPO", config=dict(num=4))

    runner.add_trial(trial1)
    runner.add_trial(trial2)
    runner.add_trial(trial3)
    runner.add_trial(trial4)

    scheduler.on_trial_add(runner, trial1)
    scheduler.on_trial_add(runner, trial2)
    scheduler.on_trial_add(runner, trial3)
    scheduler.on_trial_add(runner, trial4)

    # Add initial results.
    scheduler.on_trial_result(
        runner, trial1, result=dict(training_iteration=1, error=50)
    )
    scheduler.on_trial_result(
        runner, trial2, result=dict(training_iteration=1, error=50)
    )
    scheduler.on_trial_result(
        runner, trial3, result=dict(training_iteration=1, error=10)
    )
    scheduler.on_trial_result(
        runner, trial4, result=dict(training_iteration=1, error=100)
    )

    # Add more results. Without burn-in, this would now exploit.
    scheduler.on_trial_result(
        runner, trial1, result=dict(training_iteration=30, error=50)
    )
    scheduler.on_trial_result(
        runner, trial2, result=dict(training_iteration=30, error=50)
    )
    scheduler.on_trial_result(
        runner, trial3, result=dict(training_iteration=30, error=10)
    )
    scheduler.on_trial_result(
        runner, trial4, result=dict(training_iteration=30, error=100)
    )

    self.assertEqual(trial4.config["num"], 4)

    # Add more results. Since this is after burn-in, it should now exploit.
    scheduler.on_trial_result(
        runner, trial1, result=dict(training_iteration=50, error=50)
    )
    scheduler.on_trial_result(
        runner, trial2, result=dict(training_iteration=50, error=50)
    )
    scheduler.on_trial_result(
        runner, trial3, result=dict(training_iteration=50, error=10)
    )
    scheduler.on_trial_result(
        runner, trial4, result=dict(training_iteration=50, error=100)
    )

    self.assertEqual(trial4.config["num"], 3)

    # Assert that trials do not hang after `burn_in_period`.
    self.assertTrue(all(t.status == "PAUSED" for t in runner.get_trials()))
    self.assertTrue(scheduler.choose_trial_to_run(runner))

    # Assert that trials do not hang when a terminated trial is added.
    trial5 = Trial("PPO", config=dict(num=5))
    runner.add_trial(trial5)
    scheduler.on_trial_add(runner, trial5)
    trial5.set_status(Trial.TERMINATED)
    self.assertTrue(scheduler.choose_trial_to_run(runner))
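# A minimal sketch of the gating this test exercises: PBT does not consider
# exploiting a trial until `time_attr` has passed `burn_in_period`, and then
# only on the usual `perturbation_interval` schedule. This is an illustrative
# reduction, not PopulationBasedTraining's actual code.
def _ready_to_exploit(time, last_perturbation_time, burn_in_period, perturbation_interval):
    if time < burn_in_period:
        # Still burning in (e.g. training_iteration=30 above): never exploit.
        return False
    # Past burn-in (e.g. training_iteration=50 above): exploit whenever a
    # full perturbation interval has elapsed.
    return time - last_perturbation_time >= perturbation_interval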