예제 #1
0
def build_lineages_for_exploit(space,
                               monkeypatch,
                               trials=None,
                               elites=None,
                               additional_trials=None,
                               seed=1,
                               num=10):
    """Build a ``Lineages`` object whose trial lookups are stubbed out.

    Parameters
    ----------
    space:
        Object used to sample trials through ``space.sample(num, seed=...)``
        when ``trials`` or ``elites`` is not given.
    monkeypatch:
        Object exposing ``setattr(obj, name, value)``; used to replace
        ``get_trials_at_depth`` and ``get_elites`` on the new lineages.
    trials:
        Trials returned by the patched ``get_trials_at_depth``. When ``None``,
        ``num`` trials are sampled with ``seed``, marked ``"completed"`` and
        given an objective equal to their index.
    elites:
        Trials returned by the patched ``get_elites``. When ``None``, ``num``
        trials are sampled with ``seed + 1``, marked ``"completed"`` and given
        an objective equal to twice their index.
    additional_trials:
        Optional extra trials appended after ``trials``; they are left as-is.
    seed:
        Seed for sampling the default trials (``seed + 1`` for the elites).
    num:
        Number of trials and elites to sample by default.

    Returns
    -------
    The ``Lineages`` instance with both lookups patched to ignore their
    arguments and return the prepared lists.
    """
    if trials is None:
        trials = space.sample(num, seed=seed)
        for i, trial in enumerate(trials):
            trial.status = "completed"
            trial._results.append(
                trial.Result(name="objective", type="objective", value=i))
    if elites is None:
        elites = space.sample(num, seed=seed + 1)
        for i, trial in enumerate(elites):
            trial.status = "completed"
            trial._results.append(
                trial.Result(name="objective", type="objective", value=i * 2))

    if additional_trials:
        # Build a new list instead of using ``+=`` so that a list passed in by
        # the caller through ``trials`` is not extended in place.
        trials = list(trials) + list(additional_trials)

    def return_trials(*args, **kwargs):
        return trials

    def return_elites(*args, **kwargs):
        return elites

    lineages = Lineages()
    monkeypatch.setattr(lineages, "get_trials_at_depth", return_trials)
    monkeypatch.setattr(lineages, "get_elites", return_elites)

    return lineages
예제 #2
0
    def test_truncate_trial_not_in_trials(self, space, monkeypatch):
        """Exploiting a trial the lineages do not return raises ``ValueError``."""
        # This trial is sampled separately and never handed to the helper, so
        # the patched ``get_trials_at_depth`` does not include this object.
        outsider = space.sample(1, seed=2)[0]
        lineages = build_lineages_for_exploit(space, monkeypatch)

        exploit = self.constructor()
        expected_message = (
            f"Trial {outsider.id} not included in list of completed trials.")
        rng = numpy.random.RandomState(1)

        with pytest.raises(ValueError, match=expected_message):
            exploit(rng, outsider, lineages)
예제 #3
0
    def test_truncate(self, truncation_quantile, space, monkeypatch):
        """Check which trials ``_truncate`` keeps or replaces for a quantile.

        With the completed trials sorted by objective and
        ``cutoff = int(truncation_quantile * len(trials))``, the trial just
        below the cutoff must be returned unchanged without consulting the
        RNG, while the trial at the cutoff must be replaced by the trial the
        RNG stub selects.
        """
        lineages = build_lineages_for_exploit(space, monkeypatch)
        ranked = sorted(
            self.get_trials(lineages, TrialStub(objective=50)),
            key=lambda candidate: candidate.objective.value,
        )

        cutoff = int(truncation_quantile * len(ranked))

        kept_trial = ranked[cutoff - 1]
        # The replacing stub below returns index -1, so the last trial in
        # sorted order is the one expected back.
        replacement_trial = ranked[-1]

        # Mix in freshly sampled trials (never marked completed here) and
        # shuffle, to check that the list is filtered and sorted properly.
        pool = ranked + space.sample(20, seed=2)
        numpy.random.shuffle(pool)

        exploit = self.constructor(truncation_quantile=truncation_quantile,
                                   candidate_pool_ratio=0.2)

        def forbidden_choice(choices, *args, **kwargs):
            raise RuntimeError("Should not be called")

        def pick_last(choices, *args, **kwargs):
            return -1

        if truncation_quantile > 0.0:
            # A trial below the cutoff must come back untouched; the stub
            # raises if ``choice`` is ever called.
            rng = RNGStub()
            rng.choice = forbidden_choice

            result = exploit._truncate(rng, kept_trial, pool)

            assert result is kept_trial

        if truncation_quantile < 1.0:
            # A trial at the cutoff must be swapped for the stub's pick.
            replaced_trial = ranked[cutoff]

            rng = RNGStub()
            rng.choice = pick_last

            result = exploit._truncate(rng, replaced_trial, pool)

            assert result is replacement_trial
예제 #4
0
    def test_truncate_non_completed_trials(self, space, monkeypatch):
        """A trial that is listed but was not marked completed raises ``ValueError``."""
        pending = space.sample(1, seed=2)[0]

        # Passed as an additional trial, so it is appended as-is: the helper
        # only marks the trials it samples itself as completed.
        lineages = build_lineages_for_exploit(
            space, monkeypatch, additional_trials=[pending])

        # Sanity check: the trial really is among those the lineages return.
        assert pending in lineages.get_trials_at_depth(pending)

        exploit = self.constructor()
        expected_message = (
            f"Trial {pending.id} not included in list of completed trials.")
        rng = numpy.random.RandomState(1)

        with pytest.raises(ValueError, match=expected_message):
            exploit(rng, pending, lineages)
예제 #5
0
    def test_truncate_valid_choice(self, candidate_pool_ratio, space,
                                   monkeypatch):
        """Check the indices ``_truncate`` offers to the RNG for a pool ratio.

        The RNG stub asserts that it is offered exactly the indices
        ``0 .. int(candidate_pool_ratio * n) - 1`` (``n`` being the number of
        completed trials) and returns the last one; the matching trial in
        sorted order must be the result.
        """
        lineages = build_lineages_for_exploit(space, monkeypatch)
        pool = sorted(
            self.get_trials(lineages, TrialStub(objective=50)),
            key=lambda candidate: candidate.objective.value,
        )

        num_completed_trials = len(pool)
        expected_indices = list(
            range(int(candidate_pool_ratio * num_completed_trials)))
        expected_trial = pool[expected_indices[-1]]

        def checking_choice(choices, *args, **kwargs):
            assert choices.tolist() == expected_indices
            return expected_indices[-1]

        rng = RNGStub()
        rng.choice = checking_choice

        # Start from an arbitrary completed trial, picked before the list is
        # extended with the extra samples below.
        start_index = numpy.random.choice(range(num_completed_trials))
        start_trial = pool[start_index]

        # Mix in freshly sampled trials (never marked completed here) and
        # shuffle, to check that the list is filtered and sorted properly.
        pool.extend(space.sample(20, seed=2))
        numpy.random.shuffle(pool)

        exploit = self.constructor(truncation_quantile=0,
                                   candidate_pool_ratio=candidate_pool_ratio)

        result = exploit._truncate(rng, start_trial, pool)

        assert result is expected_trial