コード例 #1
0
    def test_perturb_int_no_duplicate_below(self):
        """perturb_int must not hand back its input value; here 1 becomes 0.

        With ``factor=0.75`` and ``rng.random()`` pinned to 1.0 the result is
        expected to be 0, one below the original value 1, not 1 itself.
        """
        explorer = PerturbExplore(factor=0.75)
        stub = RNGStub()
        stub.random = lambda: 1.0

        perturbed = explorer.perturb_int(stub, 1, (0, 10))
        assert perturbed == 0
コード例 #2
0
    def test_perturb_int_duplicate_equal(self):
        """With ``factor=1.0`` the perturbed int is expected to equal the input.

        ``rng.random()`` is pinned to 1.0 and the input 1 must come back as 1.
        """
        explorer = PerturbExplore(factor=1.0)
        stub = RNGStub()
        stub.random = lambda: 1.0

        result = explorer.perturb_int(stub, 1, (0, 10))
        assert result == 1
コード例 #3
0
    def test_perturb_real_volatility_above(self, volatility):
        """A real value above the interval ends up ``volatility`` below its top.

        The input 3.0 lies above ``(1.0, 2.0)``. ``rng.normal`` is stubbed to
        return its second argument (presumably the volatility). The result is
        expected to be ``2.0 - volatility``.
        """
        explorer = PerturbExplore(factor=1.0, volatility=volatility)
        stub = RNGStub()
        stub.random = lambda: 1.0

        def fake_normal(mean, variance):
            return variance

        stub.normal = fake_normal

        outcome = explorer.perturb_real(stub, 3.0, (1.0, 2.0))
        assert outcome == 2.0 - volatility
コード例 #4
0
    def test_perturb_cat(self):
        """perturb_cat must return one of the categories of the dimension."""
        explorer = PerturbExplore()
        stub = RNGStub()
        stub.randint = lambda low, high, size: [1]
        stub.choice = lambda choices: choices[0]

        categories = ["one", "two", 3, 4.0]
        dimension = Categorical("name", categories)

        picked = explorer.perturb_cat(stub, "whatever", dimension)
        assert picked in dimension
コード例 #5
0
    def test_perturb_real_factor(self, factor):
        """Perturbing 1.0 yields ``factor`` or ``1.0 / factor``.

        ``rng.random()`` pinned to 1.0 must give ``factor``. Pinned to 0.0 it
        must give ``1.0 / factor``.
        """
        explorer = PerturbExplore(factor=factor)
        stub = RNGStub()

        def perturbed_with(draw):
            # Each call builds its own closure over ``draw``.
            stub.random = lambda: draw
            return explorer.perturb_real(stub, 1.0, (0.1, 2.0))

        assert perturbed_with(1.0) == factor
        assert perturbed_with(0.0) == 1.0 / factor
コード例 #6
0
    def test_perturb_real_above_interval_cap(self):
        """A value above the interval is kept within ``(1.0, 2.0)``.

        With ``volatility=0`` the input 3.0 is expected to land on the upper
        bound 2.0. After raising ``volatility`` to 1000 it is expected to land
        on the lower bound 1.0.
        """
        interval = (1.0, 2.0)
        explorer = PerturbExplore(factor=1.0, volatility=0)
        stub = RNGStub()
        stub.random = lambda: 1.0
        stub.normal = lambda mean, variance: variance

        assert explorer.perturb_real(stub, 3.0, interval) == 2.0

        explorer.volatility = 1000
        assert explorer.perturb_real(stub, 3.0, interval) == 1.0
コード例 #7
0
    def test_perturb_int_factor(self, factor):
        """Perturbing the int 5 scales it by ``factor`` or ``1 / factor``.

        The expected values are the numpy-rounded products converted to int.
        ``rng.random()`` is pinned to 1.0 for multiplication and to 0.0 for
        division.
        """
        explorer = PerturbExplore(factor=factor)
        stub = RNGStub()

        def perturbed_with(draw):
            stub.random = lambda: draw
            return explorer.perturb_int(stub, 5, (0, 10))

        assert perturbed_with(1.0) == int(numpy.round(5 * factor))
        assert perturbed_with(0.0) == int(numpy.round(5 / factor))
コード例 #8
0
    def test_resample_probability(self, space):
        """ResampleExplore resamples depending on ``rng.random()``.

        With ``probability=0.5``, a draw of 0.5 returns the very same params
        object. A draw of 0.4 returns a different object.
        """
        explorer = ResampleExplore(probability=0.5)
        stub = RNGStub()
        stub.randint = lambda low, high, size: [1]
        original = {"x": 1.0, "y": 2, "z": 0, "f": 10}

        stub.random = lambda: 0.5
        assert explorer(stub, space, original) is original

        stub.random = lambda: 0.4
        assert explorer(stub, space, original) is not original
コード例 #9
0
 def test_exploit_otherwise_next(self):
     """When a stage returns None, PipelineExploit moves on to the next one.

     For each ``expected`` in 0..3, the first ``expected`` stub stages return
     None and the remaining ones return ``expected``. The pipeline must
     therefore produce ``expected``.
     """
     for expected in range(4):
         stages = [
             {
                 "of_type": "exploitstub",
                 "rval": expected if position >= expected else None,
                 "some": "args",
             }
             for position in range(4)
         ]
         pipeline = PipelineExploit(stages)
         assert pipeline(RNGStub(), TrialStub(), None) == expected
コード例 #10
0
    def test_perturb_with_invalid_dim(self, space, monkeypatch):
        """An unknown dimension type makes PerturbExplore raise ValueError."""
        explorer = PerturbExplore()

        monkeypatch.setattr(Dimension, "type", "type_that_dont_exist")

        params = {"x": 1.0, "y": 2, "z": 0, "f": 10}
        expected_message = "Unsupported dimension type type_that_dont_exist"
        with pytest.raises(ValueError, match=expected_message):
            explorer(RNGStub(), space, params)
コード例 #11
0
    def test_fetch_trials_properly(self, space, monkeypatch):
        """The exploit must pass ``_truncate`` the trials taken from the lineages.

        ``_truncate`` is replaced by a checker. The checker compares the trials
        it receives with ``self.get_trials(lineages, trial)``. It also records
        that it ran, so the test cannot pass silently if ``_truncate`` is never
        reached.
        """
        lineages = build_lineages_for_exploit(space, monkeypatch)
        exploit = self.constructor()
        calls = []

        def check_truncate_args(rng, trial, trials):
            assert trials == self.get_trials(lineages, trial)
            calls.append(trial)

        monkeypatch.setattr(exploit, "_truncate", check_truncate_args)

        exploit(RNGStub(), TrialStub(id="selected-trial"), lineages)

        # Without this, the assertion inside the checker could be skipped
        # entirely and the test would still pass.
        assert calls, "_truncate was never called by the exploit"
コード例 #12
0
    def test_truncate_valid_choice(self, candidate_pool_ratio, space,
                                   monkeypatch):
        """Test the pool of available trials based on candidate_pool_ratio.

        ``truncation_quantile=0`` is used and an arbitrary completed trial is
        truncated, so the trial is expected to be replaced. The mocked
        ``rng.choice`` checks that it is offered exactly the indices
        ``0 .. int(candidate_pool_ratio * n) - 1`` of the sorted completed
        trials. It then returns the last of those indices.
        """
        # Local seeded generator: the test is reproducible and does not touch
        # the global numpy random state shared with other tests.
        local_random = numpy.random.RandomState(1)

        lineages = build_lineages_for_exploit(space, monkeypatch)
        trials = self.get_trials(lineages, TrialStub(objective=50))
        trials = sorted(trials, key=lambda trial: trial.objective.value)

        num_completed_trials = len(trials)
        valid_choices = numpy.arange(
            int(candidate_pool_ratio * num_completed_trials)).tolist()
        selected_trial = trials[valid_choices[-1]]

        def mocked_choice(choices, *args, **kwargs):
            assert choices.tolist() == valid_choices
            return valid_choices[-1]

        rng = RNGStub()
        rng.choice = mocked_choice

        completed_trial = trials[local_random.randint(len(trials))]

        # Add non completed trials and shuffle the list to test it is filtered
        # and sorted properly
        trials += space.sample(20, seed=2)
        local_random.shuffle(trials)

        exploit = self.constructor(truncation_quantile=0,
                                   candidate_pool_ratio=candidate_pool_ratio)

        trial = exploit._truncate(
            rng,
            completed_trial,
            trials,
        )

        assert trial is selected_trial
コード例 #13
0
    def test_perturb(self, space):
        """Every perturbed value must still belong to its dimension."""
        stub = RNGStub()
        stub.randint = lambda low, high, size: [1]
        stub.random = lambda: 1.0
        stub.normal = lambda mean, variance: 0.0
        stub.choice = lambda choices: choices[0]

        original = {"x": 1.0, "y": 2, "z": 0, "f": 10}
        perturbed = PerturbExplore()(stub, space, original)

        for name in space.keys():
            assert perturbed[name] in space[name]
コード例 #14
0
    def test_truncate(self, truncation_quantile, space, monkeypatch):
        """Test that truncation_quantile sets which trials get replaced.

        A trial just inside the kept quantile must be returned as-is, without
        drawing from the rng. A trial just past the threshold must be replaced
        by the trial that the mocked choice selects.
        """
        # Local seeded generator: reproducible shuffle without touching the
        # global numpy random state shared with other tests.
        local_random = numpy.random.RandomState(1)

        # Test that trial within threshold is not replaced
        lineages = build_lineages_for_exploit(space, monkeypatch)
        trials = self.get_trials(lineages, TrialStub(objective=50))
        trials = sorted(trials, key=lambda trial: trial.objective.value)

        threshold_index = int(truncation_quantile * len(trials))

        # Last trial inside the kept quantile. This assumes threshold_index > 0
        # whenever truncation_quantile > 0 for the parametrized values.
        good_trial = trials[threshold_index - 1]
        # Worst completed trial. The mocked choice below returns -1, which is
        # presumably used as an index into the sorted completed trials.
        selected_trial = trials[-1]

        # Add non completed trials and shuffle the list to test it is filtered
        # and sorted properly
        lots_of_trials = trials + space.sample(20, seed=2)
        local_random.shuffle(lots_of_trials)

        exploit = self.constructor(truncation_quantile=truncation_quantile,
                                   candidate_pool_ratio=0.2)

        if truncation_quantile > 0.0:

            # A kept trial must never trigger a draw from the rng.
            def forbidden_choice(choices, *args, **kwargs):
                raise RuntimeError("Should not be called")

            rng = RNGStub()
            rng.choice = forbidden_choice

            trial = exploit._truncate(
                rng,
                good_trial,
                lots_of_trials,
            )

            assert trial is good_trial

        if truncation_quantile < 1.0:
            bad_trial = trials[threshold_index]

            def last_index_choice(choices, *args, **kwargs):
                return -1

            rng = RNGStub()
            rng.choice = last_index_choice

            trial = exploit._truncate(
                rng,
                bad_trial,
                lots_of_trials,
            )

            assert trial is selected_trial
コード例 #15
0
    def test_perturb_int_no_out_of_bounds(self):
        """perturb_int must keep values at the edges of ``(0, 10)`` in range.

        With ``rng.random()`` at 1.0 the lower edge 0 stays 0. With
        ``rng.random()`` at 0.0 the upper edge 10 stays 10.
        """
        explorer = PerturbExplore(factor=0.75, volatility=0)
        stub = RNGStub()
        # The same normal stub serves both checks, so it is set only once.
        stub.normal = lambda mean, variance: variance

        stub.random = lambda: 1.0
        assert explorer.perturb_int(stub, 0, (0, 10)) == 0

        stub.random = lambda: 0.0
        assert explorer.perturb_int(stub, 10, (0, 10)) == 10
コード例 #16
0
    def test_perturb_hierarchical_params(self, hspace):
        """Perturbing nested params keeps the nesting and stays inside hspace.

        The result must still contain the ``numerical`` sub-dict, including
        ``x``. Every value, looked up through ``flatten``, must belong to its
        dimension in ``hspace``.
        """
        explore = PerturbExplore()
        rng = RNGStub()
        rng.randint = lambda low, high, size: [1]
        rng.random = lambda: 1.0
        rng.normal = lambda mean, variance: 0.0
        rng.choice = lambda choices: choices[0]

        params = {"numerical": {"x": 1.0, "y": 2, "f": 10}, "z": 0}
        new_params = explore(rng, hspace, params)
        assert "numerical" in new_params
        assert "x" in new_params["numerical"]

        # Flatten once instead of re-flattening on every loop iteration.
        flat_params = flatten(new_params)
        for key in hspace.keys():
            assert flat_params[key] in hspace[key]
コード例 #17
0
    def test_truncate_not_enough_trials(self, space, monkeypatch):
        """The exploit returns None when the population is too small.

        ``num=4`` presumably builds fewer trials than the required
        ``min_forking_population=5``. The exploit must then return None.
        """
        exploit = self.constructor(min_forking_population=5)
        lineages = build_lineages_for_exploit(space, monkeypatch, num=4)

        result = exploit(RNGStub(), TrialStub(), lineages)
        assert result is None
コード例 #18
0
 def test_no_exploit(self):
     """An empty PipelineExploit returns the given trial object itself."""
     original_trial = TrialStub()
     pipeline = PipelineExploit([])
     result = pipeline(RNGStub(), original_trial, None)
     assert result is original_trial
コード例 #19
0
 def test_no_explore(self):
     """An empty PipelineExplore returns the very same params object."""
     sentinel = object()
     pipeline = PipelineExplore([])
     result = pipeline(RNGStub(), None, sentinel)
     assert result is sentinel