Exemple #1
0
    def testSavePauseResumeErrorRestore(self):
        """Check that a pause checkpoint does not replace the restore checkpoint.

        Walks a single trial through save -> train -> pause -> resume ->
        error -> restore, then stops it. The scenario passes when the pause
        leaves an in-memory checkpoint on the trial, ``restore`` completes
        without raising, and the trial ends up TERMINATED.
        """
        fake_trial = Trial("__fake")

        # Start the trial and pull a first result before the save step.
        self._simulate_starting_trial(fake_trial)
        self._simulate_getting_result(fake_trial)

        # Save step.
        self._simulate_saving(fake_trial)

        # Train further, so the trial has moved on since the save step.
        self.trial_executor.continue_training(fake_trial)
        self._simulate_getting_result(fake_trial)

        # Pause: the trial must report PAUSED and expose a MEMORY checkpoint.
        self.trial_executor.pause_trial(fake_trial)
        self.assertEqual(Trial.PAUSED, fake_trial.status)
        pause_checkpoint = fake_trial.checkpoint
        self.assertEqual(pause_checkpoint.storage_mode, CheckpointStorage.MEMORY)

        # Resume the paused trial.
        self._simulate_starting_trial(fake_trial)

        # Force the trial into the error state.
        fake_trial.set_status(Trial.ERROR)

        # Restore after the error; no assertion here, it only must not raise.
        self.trial_executor.restore(fake_trial)

        # Stopping the trial should leave it TERMINATED.
        self.trial_executor.stop_trial(fake_trial)
        self.assertEqual(Trial.TERMINATED, fake_trial.status)
Exemple #2
0
    def testBurnInPeriod(self):
        """Check that PBT does not exploit before ``burn_in_period`` is reached.

        Four mock trials (always PAUSED, fixed in-memory checkpoint) report
        the same per-trial errors at iterations 1, 30 and 50 to a scheduler
        built with ``burn_in_period=50``. The trial reporting the highest
        error must keep its own config at iteration 30 and carry ``num == 3``
        after iteration 50. The scheduler must then still return a trial to
        run, including after a terminated trial has been added.
        """
        runner = TrialRunner(trial_executor=MagicMock())

        pbt = PopulationBasedTraining(
            time_attr="training_iteration",
            metric="error",
            mode="min",
            perturbation_interval=5,
            burn_in_period=50,
            log_config=True,
            synch=True,
        )

        class MockTrial(Trial):
            # Trial stub: constant MEMORY checkpoint, status pinned to PAUSED.

            @property
            def checkpoint(self):
                return _TrackedCheckpoint(
                    dir_or_data={"data": "None"},
                    storage_mode=CheckpointStorage.MEMORY,
                    metrics={},
                )

            @property
            def status(self):
                return Trial.PAUSED

            @status.setter
            def status(self, status):
                # Status writes are ignored so the trial always reads PAUSED.
                pass

        # Build all four trials first, then register them with the runner,
        # then with the scheduler (three separate passes, in trial order).
        mock_trials = [
            MockTrial("PPO", config={"num": idx}) for idx in (1, 2, 3, 4)
        ]
        worst_trial = mock_trials[3]

        for mock_trial in mock_trials:
            runner.add_trial(mock_trial)

        for mock_trial in mock_trials:
            pbt.on_trial_add(runner, mock_trial)

        # Error reported by trial1..trial4 at every iteration; the third trial
        # has the lowest error and the fourth the highest (metric mode "min").
        errors = (50, 50, 10, 100)

        def report_all(iteration):
            # Feed one result per trial, in trial order, for this iteration.
            for mock_trial, err in zip(mock_trials, errors):
                pbt.on_trial_result(
                    runner,
                    mock_trial,
                    result={"training_iteration": iteration, "error": err},
                )

        # Initial results.
        report_all(1)

        # Still inside the burn-in period: without burn-in this would exploit.
        report_all(30)
        self.assertEqual(worst_trial.config["num"], 4)

        # Burn-in period reached: the scheduler should exploit now.
        report_all(50)
        self.assertEqual(worst_trial.config["num"], 3)

        # Trials must not hang after `burn_in_period`.
        self.assertTrue(
            all(known.status == "PAUSED" for known in runner.get_trials())
        )
        self.assertTrue(pbt.choose_trial_to_run(runner))

        # Trials must not hang when a terminated trial is added either.
        terminated_trial = Trial("PPO", config={"num": 5})
        runner.add_trial(terminated_trial)
        pbt.on_trial_add(runner, terminated_trial)
        terminated_trial.set_status(Trial.TERMINATED)
        self.assertTrue(pbt.choose_trial_to_run(runner))