예제 #1
0
    def test_register_bracket_multi_fidelity(self, space: Space):
        """Check that a point is registered inside the same bracket for diff fidelity.

        The same hyperparameters are observed first at fidelity 1 and then at
        fidelity 3. Both observations must land in the first bracket (rung 0
        and rung 1 respectively) under the same fidelity-agnostic id.
        """
        hyperband = Hyperband(space)

        value = 50
        fidelity = 1
        trial = create_trial_for_hb((fidelity, value), 0.0)
        trial_id = hyperband.get_id(trial, ignore_fidelity=True)

        force_observe(hyperband, trial)
        assert hyperband.brackets is not None
        bracket = hyperband.brackets[0]

        # Assert on the results mapping: the rung dict itself always has keys,
        # so checking its length could never fail.
        assert bracket.rungs[0]["results"]
        assert trial_id in bracket.rungs[0]["results"]
        assert bracket.rungs[0]["results"][trial_id][0] == 0.0
        assert bracket.rungs[0]["results"][trial_id][1].params == trial.params

        fidelity = 3
        trial = create_trial_for_hb((fidelity, value), 0.0)
        trial_id = hyperband.get_id(trial, ignore_fidelity=True)

        force_observe(hyperband, trial)

        assert bracket.rungs[1]["results"]
        assert trial_id in bracket.rungs[1]["results"]
        # Rung 0 still holds the fidelity-1 version under the same id.
        assert bracket.rungs[0]["results"][trial_id][1].params != trial.params
        assert bracket.rungs[1]["results"][trial_id][0] == 0.0
        assert bracket.rungs[1]["results"][trial_id][1].params == trial.params
예제 #2
0
    def test_get_id(self, space):
        """Test valid id of points.

        By default ``get_id`` disregards the first (fidelity) entry, so points
        differing only there share an id; with ``ignore_fidelity=False`` every
        entry contributes to the id.
        """
        hb = Hyperband(space)
        strict = dict(ignore_fidelity=False)

        default_id = hb.get_id(['whatever', 1])
        assert default_id == hb.get_id(['is here', 1])
        assert default_id != hb.get_id(['is here', 2])

        strict_id = hb.get_id(['whatever', 1], **strict)
        assert strict_id != hb.get_id(['is here', 1], **strict)
        assert strict_id != hb.get_id(['is here', 2], **strict)

        same_strict = hb.get_id(['same', 1], **strict)
        assert same_strict == hb.get_id(['same', 1], **strict)
        assert same_strict != hb.get_id(['same', 1])
예제 #3
0
    def test_suggest_duplicates_between_calls(self, monkeypatch,
                                              hyperband: Hyperband,
                                              bracket: HyperbandBracket):
        """A trial already registered in a bracket must not be suggested again
        by a later ``suggest`` call of the same hyperband execution.
        """
        hyperband.brackets = [bracket]
        assert bracket.owner is hyperband

        fidelity = 1
        already_seen = create_trial_for_hb((fidelity, 0.0))
        fresh = create_trial_for_hb((fidelity, 0.5))

        # Seed rung 0 of bracket 0 with `already_seen` and map its id there.
        seen_id = hyperband.get_id(already_seen, ignore_fidelity=True)
        bracket.rungs[0]["results"] = {seen_id: (0.0, already_seen)}  # type: ignore
        hyperband.trial_to_brackets[seen_id] = 0

        # The sampler yields the duplicate first, then the fresh trial, then filler.
        filler = [create_trial_for_hb((fidelity, i)) for i in range(10 - 2)]
        mock_samples(hyperband, [already_seen, fresh] + filler)

        suggested = hyperband.suggest(100)
        assert suggested is not None
        assert suggested[0].params == fresh.params
예제 #4
0
    def test_suggest_duplicates_one_call(self, monkeypatch,
                                         hyperband: Hyperband,
                                         bracket: HyperbandBracket):
        """Test that same points are not allowed in the same suggest call of
        the same hyperband execution.

        Duplicated samples inside one ``suggest`` call are dropped, and values
        already suggested at fidelity 1 are skipped when they reappear at
        fidelity 3 in a later call.
        """
        hyperband.brackets = [bracket]
        assert bracket.owner is hyperband

        zhe_point = list(
            map(create_trial_for_hb, [(1, 0.0), (1, 1.0), (1, 1.0), (1, 2.0)]))

        mock_samples(hyperband, zhe_point * 2)
        zhe_samples = hyperband.suggest(100)
        assert zhe_samples is not None
        # lr=1.0 is sampled twice but suggested only once.
        assert zhe_samples[0].params["lr"] == 0.0
        assert zhe_samples[1].params["lr"] == 1.0
        assert zhe_samples[2].params["lr"] == 2.0

        mock_samples(
            hyperband,
            [
                create_trial_for_hb((3, lr))
                for lr in (0.0, 1.0, 1.0, 2.0, 5.0, 4.0)
            ],
        )
        duplicate_id = hyperband.get_id(create_trial_for_hb((1, 0.0)),
                                        ignore_fidelity=True)
        hyperband.trial_to_brackets[duplicate_id] = 0
        zhe_samples = hyperband.suggest(100)
        assert zhe_samples is not None
        # 0.0, 1.0 and 2.0 were already suggested above; only new values remain.
        assert zhe_samples[0].params["lr"] == 5.0
        assert zhe_samples[1].params["lr"] == 4.0
예제 #5
0
    def test_get_id_multidim(self):
        """``get_id`` is well defined when a non-fidelity dimension has shape 2.

        The fidelity entry is ignored by default and taken into account with
        ``ignore_fidelity=False``.
        """
        search_space = Space()
        search_space.register(Fidelity('epoch', 1, 9, 3))
        search_space.register(Real('lr', 'uniform', 0, 1, shape=2))

        hb = Hyperband(search_space)
        strict = dict(ignore_fidelity=False)

        default_id = hb.get_id(['whatever', [1, 1]])
        assert default_id == hb.get_id(['is here', [1, 1]])
        assert default_id != hb.get_id(['is here', [2, 2]])

        strict_id = hb.get_id(['whatever', [1, 1]], **strict)
        assert strict_id != hb.get_id(['is here', [1, 1]], **strict)
        assert strict_id != hb.get_id(['is here', [2, 2]], **strict)

        same_strict = hb.get_id(['same', [1, 1]], **strict)
        assert same_strict == hb.get_id(['same', [1, 1]], **strict)
        assert same_strict != hb.get_id(['same', [1, 1]])
예제 #6
0
    def test_get_id(self, space):
        """Test valid id of points, with and without the fidelity entry."""
        hyperband = Hyperband(space)

        def full_id(point):
            # Id computed with the first (fidelity) entry taken into account.
            return hyperband.get_id(point, ignore_fidelity=False)

        # Default ids ignore the first entry, so only the second one matters.
        assert hyperband.get_id(["whatever", 1]) == hyperband.get_id(["is here", 1])
        assert hyperband.get_id(["whatever", 1]) != hyperband.get_id(["is here", 2])

        # Including the fidelity, differing first entries yield different ids.
        assert full_id(["whatever", 1]) != full_id(["is here", 1])
        assert full_id(["whatever", 1]) != full_id(["is here", 2])

        # Identical points agree, yet the full id differs from the default one.
        assert full_id(["same", 1]) == full_id(["same", 1])
        assert full_id(["same", 1]) != hyperband.get_id(["same", 1])
예제 #7
0
    def test_update_rungs_return_candidate(self, hyperband: Hyperband,
                                           bracket: HyperbandBracket,
                                           rung_1: RungDict):
        """Check that ``bracket.promote`` returns a valid modified candidate.

        The first promoted candidate must carry ``epoch == 9`` while the
        original entry stays registered in rung 1.
        """
        assert bracket.owner is hyperband
        bracket.rungs[1] = rung_1
        expected = create_trial_for_hb((3, 0.0), 0.0)

        promoted = bracket.promote(1)

        expected_id = hyperband.get_id(expected, ignore_fidelity=True)
        rung_results = bracket.rungs[1]["results"]
        assert expected_id in rung_results
        assert rung_results[expected_id][1].params == expected.params
        assert promoted[0].params["epoch"] == 9
예제 #8
0
    def test_register_next_bracket(self, space: Space):
        """Check that a point lands in the bracket matching its higher fidelity.

        A fidelity-3 point must be stored in bracket 1 only; a subsequent
        fidelity-9 point must be stored in bracket 2 only.
        """
        hyperband = Hyperband(space)

        def count_results(bracket):
            # Number of results stored over all rungs of `bracket`.
            return sum(len(rung["results"]) for rung in bracket.rungs)

        value, fidelity = 50, 3
        trial = create_trial_for_hb((fidelity, value), 0.0)
        trial_id = hyperband.get_id(trial, ignore_fidelity=True)

        force_observe(hyperband, trial)
        assert hyperband.brackets is not None
        assert [count_results(b) for b in hyperband.brackets[:3]] == [0, 1, 0]
        bottom_rung = hyperband.brackets[1].rungs[0]["results"]
        assert trial_id in bottom_rung
        compare_registered_trial(bottom_rung[trial_id], trial)

        value, fidelity = 51, 9
        trial = create_trial_for_hb((fidelity, value), 0.0)
        trial_id = hyperband.get_id(trial, ignore_fidelity=True)

        force_observe(hyperband, trial)

        assert [count_results(b) for b in hyperband.brackets[:3]] == [0, 1, 1]
        bottom_rung = hyperband.brackets[2].rungs[0]["results"]
        assert trial_id in bottom_rung
        compare_registered_trial(bottom_rung[trial_id], trial)
예제 #9
0
    def test_register(self, hyperband: Hyperband, bracket: HyperbandBracket):
        """Check that a point is correctly registered inside a bracket.

        After ``bracket.register`` the trial must appear in rung 0 under its
        fidelity-agnostic id, together with its objective value.
        """
        assert bracket.owner is hyperband
        trial = create_trial_for_hb((1, 0.0), 0.0)
        trial_id = hyperband.get_id(trial, ignore_fidelity=True)

        bracket.register(trial)

        results = bracket.rungs[0]["results"]
        # Assert on the results mapping: the rung dict itself always has keys,
        # so checking its length could never fail.
        assert results
        assert trial_id in results
        assert trial.objective is not None
        assert results[trial_id][0] == trial.objective.value
        assert results[trial_id][1].to_dict() == trial.to_dict()
예제 #10
0
    def test_promotion_with_rung_1_hit(self, hyperband: Hyperband,
                                       bracket: HyperbandBracket,
                                       rung_0: RungDict):
        """Test that get_candidates skips a point already present in rung 1.

        The lr=0.0 point is placed in rung 1 beforehand, so the first candidate
        returned for rung 0 must be the lr=1 point instead.
        """
        already_promoted = create_trial_for_hb((1, 0.0), None)
        assert bracket.owner is hyperband
        bracket.rungs[0] = rung_0
        assert already_promoted.objective is not None

        promoted_id = hyperband.get_id(already_promoted, ignore_fidelity=True)
        bracket.rungs[1]["results"][promoted_id] = (
            already_promoted.objective.value,
            already_promoted,
        )

        candidates = bracket.get_candidates(0)

        expected = create_trial_for_hb((1, 1), 0.0)
        assert candidates[0].params == expected.params
예제 #11
0
    def test_suggest_inf_duplicates(
        self,
        monkeypatch,
        hyperband: Hyperband,
        bracket: HyperbandBracket,
        rung_0: RungDict,
        rung_1: RungDict,
        rung_2: RungDict,
    ):
        """Test that sampling only collisions makes ``suggest`` return ``[]``.

        The mocked sampler yields nothing but copies of a trial whose id is
        already mapped to bracket 0, so no new trial can be suggested.
        """
        hyperband.brackets = [bracket]
        assert bracket.owner is hyperband

        zhe_trial = create_trial_for_hb(("fidelity", 0.0))
        hyperband.trial_to_brackets[hyperband.get_id(zhe_trial,
                                                     ignore_fidelity=True)] = 0

        mock_samples(hyperband, [zhe_trial] * 2)

        assert hyperband.suggest(100) == []
예제 #12
0
    def test_register(
        self,
        hyperband: Hyperband,
        bracket: HyperbandBracket,
        rung_0: RungDict,
        rung_1: RungDict,
    ):
        """Check that a point observed by hyperband is registered in the bracket."""
        hyperband.brackets = [bracket]
        assert bracket.owner is hyperband
        bracket.rungs = [rung_0, rung_1]
        trial = create_trial_for_hb((1, 0.0), 0.0)
        trial_id = hyperband.get_id(trial, ignore_fidelity=True)

        hyperband.observe([trial])

        rung_results = bracket.rungs[0]["results"]
        # Assert on the results mapping: the rung dict itself always has keys,
        # so checking its length could never fail.
        assert rung_results
        assert trial_id in rung_results
        assert rung_results[trial_id][0] == 0.0
        assert rung_results[trial_id][1].params == trial.params
예제 #13
0
def force_observe(hyperband: Hyperband, trial: Trial) -> None:
    """Register ``trial`` with ``hyperband`` and observe it in a chosen bracket.

    If the trial's fidelity-agnostic id is not yet mapped in
    ``hyperband.trial_to_brackets``, it is assigned to the first bracket whose
    rung 0 ``resources`` equals the trial's fidelity value.

    Raises:
        ValueError: if no bracket's first rung matches the trial's fidelity.
    """
    hyperband.register(trial)

    id_wo_fidelity = hyperband.get_id(trial, ignore_fidelity=True)
    bracket_index = hyperband.trial_to_brackets.get(id_wo_fidelity)

    if bracket_index is None:
        fidelity = flatten(trial.params)[hyperband.fidelity_index]
        assert hyperband.brackets is not None
        bracket_index = next(
            (
                i
                for i, bracket in enumerate(hyperband.brackets)
                if bracket.rungs[0]["resources"] == fidelity
            ),
            None,
        )
        if bracket_index is None:
            raise ValueError(
                f"No bracket has a first rung with resources == {fidelity!r}"
            )

    hyperband.trial_to_brackets[id_wo_fidelity] = bracket_index

    hyperband.observe([trial])
예제 #14
0
    def test_get_id_multidim(self):
        """Ids are well defined when the non-fidelity dimension has shape 2."""
        search_space = Space()
        search_space.register(Fidelity("epoch", 1, 9, 3))
        search_space.register(Real("lr", "uniform", 0, 1, shape=2))

        hyperband = Hyperband(search_space)

        def full_id(point):
            # Id computed with the first (fidelity) entry taken into account.
            return hyperband.get_id(point, ignore_fidelity=False)

        # Default ids ignore the first entry, so only the vector value matters.
        assert hyperband.get_id(["whatever", [1, 1]]) == hyperband.get_id(
            ["is here", [1, 1]]
        )
        assert hyperband.get_id(["whatever", [1, 1]]) != hyperband.get_id(
            ["is here", [2, 2]]
        )

        # Including the fidelity, differing first entries yield different ids.
        assert full_id(["whatever", [1, 1]]) != full_id(["is here", [1, 1]])
        assert full_id(["whatever", [1, 1]]) != full_id(["is here", [2, 2]])

        # Identical points agree, yet the full id differs from the default one.
        assert full_id(["same", [1, 1]]) == full_id(["same", [1, 1]])
        assert full_id(["same", [1, 1]]) != hyperband.get_id(["same", [1, 1]])