def build_lineages_for_exploit(
    space, monkeypatch, trials=None, elites=None, additional_trials=None, seed=1, num=10
):
    """Build a ``Lineages`` object whose trial lookups return canned lists.

    Parameters
    ----------
    space:
        Object exposing ``sample(n, seed=...)``; used to draw trials when
        ``trials`` or ``elites`` is not provided.
    monkeypatch:
        Pytest ``monkeypatch`` fixture used to override the lookup methods.
    trials:
        Trials returned by the patched ``get_trials_at_depth``. When ``None``,
        ``num`` trials are sampled with ``seed``, marked ``"completed"`` and
        given objectives ``0, 1, 2, ...``.
    elites:
        Trials returned by the patched ``get_elites``. When ``None``, ``num``
        trials are sampled with ``seed + 1``, marked ``"completed"`` and given
        objectives ``0, 2, 4, ...``.
    additional_trials:
        Extra trials placed after ``trials``. Their status and results are
        left untouched.
    seed:
        Base seed for sampling.
    num:
        Number of trials (and elites) to sample.

    Returns
    -------
    The ``Lineages`` instance with ``get_trials_at_depth`` and ``get_elites``
    patched to ignore their arguments and return the lists described above.
    """

    def mark_completed(sampled, scale):
        # Give every sampled trial a completed status and objective ``index * scale``.
        for index, sampled_trial in enumerate(sampled):
            sampled_trial.status = "completed"
            sampled_trial._results.append(
                sampled_trial.Result(
                    name="objective", type="objective", value=index * scale
                )
            )
        return sampled

    if trials is None:
        trials = mark_completed(space.sample(num, seed=seed), 1)
    if elites is None:
        elites = mark_completed(space.sample(num, seed=seed + 1), 2)

    if additional_trials:
        # Build a new list instead of ``+=`` so that a list passed in by the
        # caller is not extended in place behind their back.
        trials = list(trials) + list(additional_trials)

    def return_trials(*args, **kwargs):
        return trials

    def return_elites(*args, **kwargs):
        return elites

    lineages = Lineages()
    monkeypatch.setattr(lineages, "get_trials_at_depth", return_trials)
    monkeypatch.setattr(lineages, "get_elites", return_elites)
    return lineages
def test_truncate_trial_not_in_trials(self, space, monkeypatch):
    """Exploiting a trial that was not added to the lineages' trials raises ``ValueError``."""
    unknown_trial = space.sample(1, seed=2)[0]
    patched_lineages = build_lineages_for_exploit(space, monkeypatch)
    exploit = self.constructor()

    # ``match`` is applied as a regular expression against the error message.
    expected_error = (
        f"Trial {unknown_trial.id} not included in list of completed trials."
    )
    with pytest.raises(ValueError, match=expected_error):
        exploit(numpy.random.RandomState(1), unknown_trial, patched_lineages)
def test_truncate(self, truncation_quantile, space, monkeypatch):
    """Check which trials ``_truncate`` keeps or replaces for a given ``truncation_quantile``."""
    lineages = build_lineages_for_exploit(space, monkeypatch)
    completed = sorted(
        self.get_trials(lineages, TrialStub(objective=50)),
        key=lambda candidate: candidate.objective.value,
    )
    cutoff = int(truncation_quantile * len(completed))
    kept_trial = completed[cutoff - 1]
    replacement_trial = completed[-1]

    # Mix in freshly sampled trials (never marked completed) and shuffle, so
    # that ``_truncate`` has to do the filtering and sorting itself.
    shuffled_pool = completed + space.sample(20, seed=2)
    numpy.random.shuffle(shuffled_pool)

    exploit = self.constructor(
        truncation_quantile=truncation_quantile, candidate_pool_ratio=0.2
    )

    def forbidden_choice(choices, *args, **kwargs):
        raise RuntimeError("Should not be called")

    def last_choice(choices, *args, **kwargs):
        return -1

    if truncation_quantile > 0.0:
        # A trial ranked just before the cutoff must come back unchanged,
        # and the rng must not be consulted at all.
        strict_rng = RNGStub()
        strict_rng.choice = forbidden_choice
        outcome = exploit._truncate(strict_rng, kept_trial, shuffled_pool)
        assert outcome is kept_trial

    if truncation_quantile < 1.0:
        # The trial sitting at the cutoff is expected to be swapped out; the
        # stubbed rng answers -1 and the last trial in objective order is
        # the expected result.
        permissive_rng = RNGStub()
        permissive_rng.choice = last_choice
        outcome = exploit._truncate(
            permissive_rng, completed[cutoff], shuffled_pool
        )
        assert outcome is replacement_trial
def test_truncate_non_completed_trials(self, space, monkeypatch):
    """A trial present in the lineages but never marked completed raises ``ValueError``."""
    pending_trial = space.sample(1, seed=2)[0]
    patched_lineages = build_lineages_for_exploit(
        space, monkeypatch, additional_trials=[pending_trial]
    )
    # Sanity check: the trial really is among those the lineages hand back.
    known_trials = patched_lineages.get_trials_at_depth(pending_trial)
    assert pending_trial in known_trials

    exploit = self.constructor()

    expected_error = (
        f"Trial {pending_trial.id} not included in list of completed trials."
    )
    with pytest.raises(ValueError, match=expected_error):
        exploit(numpy.random.RandomState(1), pending_trial, patched_lineages)
def test_truncate_valid_choice(self, candidate_pool_ratio, space, monkeypatch):
    """Check the candidate indices offered to the rng for a given ``candidate_pool_ratio``."""
    lineages = build_lineages_for_exploit(space, monkeypatch)
    pool = sorted(
        self.get_trials(lineages, TrialStub(objective=50)),
        key=lambda candidate: candidate.objective.value,
    )

    completed_count = len(pool)
    expected_indices = numpy.arange(
        int(candidate_pool_ratio * completed_count)
    ).tolist()
    last_valid_index = expected_indices[-1]
    expected_trial = pool[last_valid_index]

    def checking_choice(choices, *args, **kwargs):
        # The rng must be offered exactly the expected candidate indices.
        assert choices.tolist() == expected_indices
        return last_valid_index

    rng = RNGStub()
    rng.choice = checking_choice

    # Pick the trial to truncate at random among the completed ones.
    current_trial = pool[numpy.random.choice(range(completed_count))]

    # Append freshly sampled trials (never marked completed) and shuffle, so
    # that ``_truncate`` has to do the filtering and sorting itself.
    pool.extend(space.sample(20, seed=2))
    numpy.random.shuffle(pool)

    exploit = self.constructor(
        truncation_quantile=0, candidate_pool_ratio=candidate_pool_ratio
    )

    outcome = exploit._truncate(rng, current_trial, pool)
    assert outcome is expected_trial