コード例 #1
0
    def test_truncate(self, truncation_quantile, space, monkeypatch):
        """Check how ``_truncate`` treats trials on each side of the quantile cut.

        Completed trials are ranked by ascending objective value and the cut
        index is ``int(truncation_quantile * number_of_completed_trials)``.
        The trial just before the cut must be returned untouched without the
        RNG being consulted; the trial at the cut must be swapped for the one
        designated by the stubbed ``rng.choice``.
        """
        lineages = build_lineages_for_exploit(space, monkeypatch)
        ranked = sorted(
            self.get_trials(lineages, TrialStub(objective=50)),
            key=lambda candidate: candidate.objective.value,
        )

        cutoff = int(truncation_quantile * len(ranked))
        kept_trial = ranked[cutoff - 1]
        replacement = ranked[-1]

        # Mix in trials sampled from the space (not completed) and shuffle, so
        # that filtering and sorting inside ``_truncate`` are exercised as well.
        pool = ranked + space.sample(20, seed=2)
        numpy.random.shuffle(pool)

        exploit = self.constructor(
            truncation_quantile=truncation_quantile, candidate_pool_ratio=0.2
        )

        def forbidden_choice(choices, *args, **kwargs):
            # No replacement should be drawn for a trial that is kept.
            raise RuntimeError("Should not be called")

        def last_index_choice(choices, *args, **kwargs):
            # Always designate index -1 as the replacement.
            return -1

        if truncation_quantile > 0.0:
            rng = RNGStub()
            rng.choice = forbidden_choice

            outcome = exploit._truncate(rng, kept_trial, pool)

            assert outcome is kept_trial

        if truncation_quantile < 1.0:
            dropped_trial = ranked[cutoff]

            rng = RNGStub()
            rng.choice = last_index_choice

            outcome = exploit._truncate(rng, dropped_trial, pool)

            assert outcome is replacement
コード例 #2
0
    def test_perturb_cat(self):
        """Check that ``perturb_cat`` returns a value belonging to the dimension."""

        def fixed_randint(low, high, size):
            return [1]

        def first_choice(choices):
            return choices[0]

        explore = PerturbExplore()

        # Deterministic RNG: randint always yields [1], choice picks the first item.
        rng = RNGStub()
        rng.randint = fixed_randint
        rng.choice = first_choice

        categories = Categorical("name", ["one", "two", 3, 4.0])
        perturbed = explore.perturb_cat(rng, "whatever", categories)
        assert perturbed in categories
コード例 #3
0
    def test_perturb(self, space):
        """Check that every perturbed parameter stays inside its dimension of ``space``."""

        def fixed_randint(low, high, size):
            return [1]

        def fixed_random():
            return 1.0

        def fixed_normal(mean, variance):
            return 0.0

        def first_choice(choices):
            return choices[0]

        explore = PerturbExplore()

        # Replace every random primitive of the stub with a constant answer.
        rng = RNGStub()
        rng.randint = fixed_randint
        rng.random = fixed_random
        rng.normal = fixed_normal
        rng.choice = first_choice

        original_params = {"x": 1.0, "y": 2, "z": 0, "f": 10}
        perturbed = explore(rng, space, original_params)

        for name in space.keys():
            dimension = space[name]
            assert perturbed[name] in dimension
コード例 #4
0
    def test_perturb_hierarchical_params(self, hspace):
        """Check that perturbing nested params keeps the nesting and valid values.

        The result must still expose ``numerical`` -> ``x``, and once
        flattened each value must belong to the matching dimension of
        ``hspace``.
        """

        def fixed_randint(low, high, size):
            return [1]

        def fixed_random():
            return 1.0

        def fixed_normal(mean, variance):
            return 0.0

        def first_choice(choices):
            return choices[0]

        explore = PerturbExplore()

        # Replace every random primitive of the stub with a constant answer.
        rng = RNGStub()
        rng.randint = fixed_randint
        rng.random = fixed_random
        rng.normal = fixed_normal
        rng.choice = first_choice

        nested_params = {"numerical": {"x": 1.0, "y": 2, "f": 10}, "z": 0}
        perturbed = explore(rng, hspace, nested_params)

        # The hierarchical layout must survive the perturbation.
        assert "numerical" in perturbed
        assert "x" in perturbed["numerical"]

        for name in hspace.keys():
            flat_value = flatten(perturbed)[name]
            assert flat_value in hspace[name]
コード例 #5
0
    def test_truncate_valid_choice(self, candidate_pool_ratio, space,
                                   monkeypatch):
        """Check which candidate indices ``_truncate`` offers to ``rng.choice``.

        With ``truncation_quantile=0`` the stubbed ``rng.choice`` must receive
        exactly the indices ``0 .. int(candidate_pool_ratio * n_completed) - 1``
        and the trial returned must be the one at the index the stub picks
        (the last valid one) in the objective-sorted completed trials.
        """
        lineages = build_lineages_for_exploit(space, monkeypatch)
        completed = self.get_trials(lineages, TrialStub(objective=50))
        completed = sorted(completed, key=lambda candidate: candidate.objective.value)

        num_completed_trials = len(completed)
        pool_size = int(candidate_pool_ratio * num_completed_trials)
        expected_indices = list(range(pool_size))
        expected_trial = completed[expected_indices[-1]]

        def checking_choice(choices, *args, **kwargs):
            # Verify the offered pool, then pick its last index.
            assert choices.tolist() == expected_indices
            return expected_indices[-1]

        rng = RNGStub()
        rng.choice = checking_choice

        # Any completed trial may serve as the one being truncated.
        picked_index = numpy.random.choice(range(num_completed_trials))
        trial_under_test = completed[picked_index]

        # Mix in trials sampled from the space (not completed) and shuffle, so
        # that filtering and sorting inside ``_truncate`` are exercised as well.
        completed.extend(space.sample(20, seed=2))
        numpy.random.shuffle(completed)

        exploit = self.constructor(
            truncation_quantile=0, candidate_pool_ratio=candidate_pool_ratio
        )

        outcome = exploit._truncate(rng, trial_under_test, completed)

        assert outcome is expected_trial