Ejemplo n.º 1
0
    def test_candidate_promotion(self, hyperband: Hyperband,
                                 bracket: HyperbandBracket, rung_0: RungDict):
        """The first candidate offered from rung 0 is the (1, 0.0) trial."""
        assert bracket.owner is hyperband
        bracket.rungs[0] = rung_0

        candidates = bracket.get_candidates(0)
        expected_trial = create_trial_for_hb((1, 0.0), 0.0)

        assert candidates[0].params == expected_trial.params
Ejemplo n.º 2
0
    def test_bad_register(self, hyperband: Hyperband,
                          bracket: HyperbandBracket):
        """Registering a trial at an unknown fidelity level raises IndexError."""
        assert bracket.owner is hyperband

        # Fidelity 55 does not correspond to any rung of the bracket.
        with pytest.raises(IndexError) as exc_info:
            bracket.register(create_trial_for_hb((55, 0.0), 0.0))

        error_message = str(exc_info.value)
        assert "Bad fidelity level 55" in error_message
Ejemplo n.º 3
0
    def test_get_trial_max_resource(
        self,
        hyperband: Hyperband,
        bracket: HyperbandBracket,
        rung_0: RungDict,
        rung_1: RungDict,
        rung_2: RungDict,
    ):
        """Check the max resource R reported for a trial as rungs are added."""
        assert bracket.owner is hyperband

        def max_resource_for(lr):
            # Query R for a fresh epoch-1 trial with the given learning rate.
            return bracket.get_trial_max_resource(
                trial=create_trial_for_hb((1, lr)))

        # Each stage installs one rung, then checks R for lr=0.0 and lr=8.0.
        stages = [
            (0, rung_0, 1, 1),
            (1, rung_1, 3, 1),
            (2, rung_2, 9, 1),
        ]
        for rung_index, rung, expected_for_lr0, expected_for_lr8 in stages:
            bracket.rungs[rung_index] = rung
            assert max_resource_for(0.0) == expected_for_lr0
            assert max_resource_for(8.0) == expected_for_lr8
Ejemplo n.º 4
0
    def test_suggest_duplicates_between_execution(self, monkeypatch, hyperband,
                                                  budgets):
        """A second repetition must not re-suggest trials from the first one.

        The trials of the first repetition are observed, then the sampler is
        mocked to hand out points that partly collide with already-registered
        trials; only the non-colliding points may come back from ``suggest``.
        """
        hyperband.repetitions = 2
        bracket = HyperbandBracket(hyperband, budgets, 1)
        hyperband.brackets = [bracket]
        bracket.hyperband = hyperband

        # Observe the first repetition rung by rung: 9 trials at epoch 1,
        # 3 at epoch 3 and a single one at epoch 9.
        for lr in range(9):
            force_observe(hyperband, create_trial_for_hb((1, lr), objective=lr))
        for lr in range(3):
            force_observe(hyperband, create_trial_for_hb((3, lr), objective=lr))
        force_observe(hyperband, create_trial_for_hb((9, 0), objective=0))

        assert not hyperband.is_done

        # Points returned by the mocked sampler, grouped by the bracket they
        # are drawn for. Colliding points are expected to be ignored.
        sampled_points = [
            # epoch-9 bracket: (9, 0) already exists.
            (9, 0), (9, 2), (9, 3), (9, 10),
            # epoch-3 bracket: (9, 1) -> (3, 1) was already promoted in the
            # last repetition; (9, 3) was sampled for the previous bracket.
            (9, 1), (9, 3), (9, 4), (9, 5), (9, 11),
            # epoch-1 bracket: (9, 0) -> (1, 0) and (9, 8) -> (1, 8) were
            # already sampled in the last repetition.
            (9, 0), (9, 8), (9, 12), (9, 13),
        ]
        mocked_trials = [create_trial_for_hb(point) for point in sampled_points]

        hyperband._refresh_brackets()
        mock_samples(hyperband, mocked_trials)
        suggestions = hyperband.suggest(100)

        assert [trial.params for trial in suggestions] == [
            {"epoch": 9, "lr": 2},
            {"epoch": 9, "lr": 3},
            {"epoch": 9, "lr": 10},
            {"epoch": 3, "lr": 4},
            {"epoch": 3, "lr": 5},
            {"epoch": 3, "lr": 11},
            {"epoch": 1, "lr": 12},
            {"epoch": 1, "lr": 13},
        ]
Ejemplo n.º 5
0
    def test_repr(
        self,
        bracket: HyperbandBracket,
        rung_0: RungDict,
        rung_1: RungDict,
        rung_2: RungDict,
    ):
        """``str(bracket)`` shows the rung resources and the repetition id."""
        for rung_index, rung in enumerate((rung_0, rung_1, rung_2)):
            bracket.rungs[rung_index] = rung

        expected_text = "HyperbandBracket(resource=[1, 3, 9], repetition id=1)"
        assert str(bracket) == expected_text
Ejemplo n.º 6
0
    def test_update_rungs_return_candidate(self, hyperband: Hyperband,
                                           bracket: HyperbandBracket,
                                           rung_1: RungDict):
        """With rung 1 populated, ``promote(1)`` returns an epoch-9 candidate.

        The test name refers to ``update_rungs``; the method actually exercised
        is ``bracket.promote``.
        """
        assert bracket.owner is hyperband
        bracket.rungs[1] = rung_1
        reference_trial = create_trial_for_hb((3, 0.0), 0.0)

        promoted = bracket.promote(1)

        # Rung 1 must still hold the reference trial after promotion.
        reference_id = hyperband.get_id(reference_trial, ignore_fidelity=True)
        rung_1_results = bracket.rungs[1]["results"]
        assert reference_id in rung_1_results
        assert rung_1_results[reference_id][1].params == reference_trial.params
        assert promoted[0].params["epoch"] == 9
Ejemplo n.º 7
0
    def test_register(self, hyperband: Hyperband, bracket: HyperbandBracket):
        """Check that a point is correctly registered inside a bracket.

        After ``bracket.register``, rung 0 must hold the trial under its
        fidelity-independent id, paired with its objective value.
        """
        assert bracket.owner is hyperband
        trial = create_trial_for_hb((1, 0.0), 0.0)
        trial_id = hyperband.get_id(trial, ignore_fidelity=True)

        bracket.register(trial)

        results = bracket.rungs[0]["results"]
        # ``len(bracket.rungs[0])`` would count the RungDict's own keys
        # (n_trials/resources/results) and always be truthy; check that the
        # results mapping itself is non-empty instead.
        assert results
        assert trial_id in results
        assert trial.objective is not None
        assert results[trial_id][0] == trial.objective.value
        assert results[trial_id][1].to_dict() == trial.to_dict()
Ejemplo n.º 8
0
    def test_no_promotion_if_not_completed(self, hyperband: Hyperband,
                                           bracket: HyperbandBracket,
                                           rung_0: RungDict):
        """Test that get_candidates fails while rung 0 trials lack objectives.

        Every objective in rung 0 is reset to ``None``; ``get_candidates(0)``
        is then expected to raise ``TypeError`` rather than return candidates.
        """
        assert bracket.owner is hyperband
        bracket.rungs[0] = rung_0
        rung = bracket.rungs[0]["results"]

        # Simulate incomplete trials: drop each objective, keep the trial.
        # Only values are replaced, so iterating the keys is safe.
        for p_id in rung.keys():
            rung[p_id] = (None, rung[p_id][1])

        with pytest.raises(TypeError):
            bracket.get_candidates(0)
Ejemplo n.º 9
0
    def test_no_promotion_when_rung_full(
        self,
        hyperband: Hyperband,
        bracket: HyperbandBracket,
        rung_0: RungDict,
        rung_1: RungDict,
    ):
        """Test that get_candidates returns an empty list if rung 1 is full.

        Both rung 0 and rung 1 are populated from fixtures, so no rung-0 trial
        should be offered for promotion.
        """
        assert bracket.owner is hyperband
        bracket.rungs[0] = rung_0
        bracket.rungs[1] = rung_1

        points = bracket.get_candidates(0)

        assert points == []
Ejemplo n.º 10
0
    def test_suggest_promote_identic_objectives(self, hyperband: Hyperband,
                                                bracket: HyperbandBracket):
        """Ties between objectives must still yield valid promotions.

        Nine rung-0 trials at resource 1 all share objective 0; asking for two
        suggestions must return two trials at fidelity 3.
        """
        hyperband.brackets = [bracket]
        assert bracket.owner is hyperband

        n_trials = 9
        resources = 1

        def _rung_entry(lr):
            # Build one (hash, (objective, trial)) pair for the rung results.
            tied_trial = create_trial_for_hb((resources, lr), objective=0)
            key = tied_trial.compute_trial_hash(
                tied_trial,
                ignore_fidelity=True,
                ignore_experiment=True,
            )
            assert tied_trial.objective is not None
            return key, (tied_trial.objective.value, tied_trial)

        results = dict(_rung_entry(lr) for lr in np.linspace(0, 8, 9))

        bracket.rungs[0] = RungDict(
            n_trials=n_trials, resources=resources, results=results)

        candidates = hyperband.suggest(2)
        assert candidates is not None
        assert len(candidates) == 2
        at_fidelity_3 = [
            candidate for candidate in candidates
            if candidate.params[hyperband.fidelity_index] == 3
        ]
        assert len(at_fidelity_3) == 2
Ejemplo n.º 11
0
    def test_suggest_duplicates_between_calls(self, monkeypatch,
                                              hyperband: Hyperband,
                                              bracket: HyperbandBracket):
        """A trial already registered in the bracket is not suggested again by
        a later ``suggest`` call of the same hyperband execution.
        """
        hyperband.brackets = [bracket]
        assert bracket.owner is hyperband

        fidelity = 1
        already_seen = create_trial_for_hb((fidelity, 0.0))
        fresh = create_trial_for_hb((fidelity, 0.5))

        # Register ``already_seen`` as if an earlier call had produced it.
        seen_id = hyperband.get_id(already_seen, ignore_fidelity=True)
        bracket.rungs[0]["results"] = {seen_id: (0.0, already_seen)}  # type: ignore
        hyperband.trial_to_brackets[seen_id] = 0

        # The sampler yields the duplicate first, then the fresh trial, then
        # filler trials.
        filler = [create_trial_for_hb((fidelity, i)) for i in range(10 - 2)]
        mock_samples(hyperband, [already_seen, fresh] + filler)

        suggested = hyperband.suggest(100)
        assert suggested is not None
        assert suggested[0].params == fresh.params
Ejemplo n.º 12
0
    def test_promotion_with_rung_1_hit(self, hyperband: Hyperband,
                                       bracket: HyperbandBracket,
                                       rung_0: RungDict):
        """If the best rung-0 trial is already in rung 1, the next best is offered."""
        trial = create_trial_for_hb((1, 0.0), None)
        assert bracket.owner is hyperband
        bracket.rungs[0] = rung_0
        assert trial.objective is not None

        # Place the (1, 0.0) trial in rung 1 as if it had already been promoted.
        promoted_id = hyperband.get_id(trial, ignore_fidelity=True)
        rung_1_results = bracket.rungs[1]["results"]
        rung_1_results[promoted_id] = (trial.objective.value, trial)

        candidates = bracket.get_candidates(0)

        next_best = create_trial_for_hb((1, 1), 0.0)
        assert candidates[0].params == next_best.params
Ejemplo n.º 13
0
    def test_update_rungs_return_no_candidate(self, hyperband: Hyperband,
                                              bracket: HyperbandBracket,
                                              rung_1: RungDict):
        """``promote(1)`` on an untouched fixture bracket returns no candidates.

        The test name refers to ``update_rungs``; the method exercised is
        ``bracket.promote``.
        """
        assert bracket.owner is hyperband
        # NOTE(review): ``rung_1`` is requested but not installed into the
        # bracket, so rung 1 keeps the fixture bracket's initial contents.
        assert bracket.promote(1) == []
Ejemplo n.º 14
0
    def test_is_done(self, bracket: HyperbandBracket, rung_0: RungDict):
        """``is_done`` becomes True once rung 2 holds three results."""
        assert not bracket.is_done

        # The stored values are placeholders; they are not important here.
        placeholder = (1, 0.0)
        bracket.rungs[2]["results"] = {
            key: placeholder for key in ("1", "2", "3")
        }  # type: ignore

        assert bracket.is_done
Ejemplo n.º 15
0
    def test_suggest_promote(self, hyperband: Hyperband,
                             bracket: HyperbandBracket, rung_0: RungDict):
        """With rung 0 populated, ``suggest`` returns the top three promoted to epoch 3."""
        hyperband.brackets = [bracket]
        assert bracket.owner is hyperband
        bracket.rungs[0] = rung_0

        suggested = hyperband.suggest(100)
        assert suggested is not None
        expected_params = [{"epoch": 3, "lr": lr} for lr in range(3)]
        assert [point.params for point in suggested] == expected_params
Ejemplo n.º 16
0
    def test_suggest_opt_out(
        self,
        hyperband: Hyperband,
        bracket: HyperbandBracket,
        rung_0: RungDict,
        rung_1: RungDict,
        rung_2: RungDict,
    ):
        """Hyperband suggests nothing while a rung still has a missing result."""
        hyperband.brackets = [bracket]
        assert bracket.owner is hyperband

        bracket.rungs[0] = rung_0

        # Erase the objective of one rung-0 trial so the rung is not ready.
        rung_0_results = rung_0["results"]
        first_id = next(iter(rung_0_results))
        _, pending_trial = rung_0_results[first_id]
        rung_0_results[first_id] = (None, pending_trial)

        assert hyperband.suggest(100) == []
Ejemplo n.º 17
0
    def test_register(
        self,
        hyperband: Hyperband,
        bracket: HyperbandBracket,
        rung_0: RungDict,
        rung_1: RungDict,
    ):
        """Check that a point observed by hyperband is registered in the bracket.

        After ``hyperband.observe``, rung 0 must hold the trial under its
        fidelity-independent id with objective 0.0.
        """
        hyperband.brackets = [bracket]
        assert bracket.owner is hyperband
        bracket.rungs = [rung_0, rung_1]
        trial = create_trial_for_hb((1, 0.0), 0.0)
        trial_id = hyperband.get_id(trial, ignore_fidelity=True)

        hyperband.observe([trial])

        results = bracket.rungs[0]["results"]
        # ``len(bracket.rungs[0])`` would count the RungDict's own keys
        # (n_trials/resources/results) and always be truthy; check that the
        # results mapping itself is non-empty instead.
        assert results
        assert trial_id in results
        assert results[trial_id][0] == 0.0
        assert results[trial_id][1].params == trial.params
Ejemplo n.º 18
0
def bracket(budgets: list[BudgetTuple], hyperband: Hyperband):
    """Return a `HyperbandBracket` owned by `hyperband`.

    The bracket is built from the `budgets` fixture with repetition id 1.
    """
    return HyperbandBracket(hyperband, budgets, 1)
Ejemplo n.º 19
0
    def test_is_ready(
        self,
        hyperband: Hyperband,
        bracket: HyperbandBracket,
        rung_0: RungDict,
        rung_1: RungDict,
        rung_2: RungDict,
    ):
        """Test that Hyperband bracket detects when rung is ready.

        Rungs are installed one at a time. For each newly installed rung, the
        objective of its first trial is set to ``None`` and then restored.
        After each step, ``is_ready()`` (no argument) and ``is_ready(i)`` are
        checked for the rung indices.
        """
        hyperband.brackets = [bracket]
        assert bracket.owner is hyperband
        bracket.rungs[0] = rung_0

        # Blank out the objective of one rung-0 trial, keeping the trial object.
        rung = bracket.rungs[0]["results"]
        trial_id = next(iter(rung.keys()))
        objective, point = rung[trial_id]
        rung[trial_id] = (None, point)

        assert not bracket.is_ready()
        assert not bracket.is_ready(0)

        # Restore it: rung 0 becomes ready; rungs 1 and 2 are not installed yet.
        rung[trial_id] = (objective, point)

        assert bracket.is_ready()
        assert bracket.is_ready(0)
        assert not bracket.is_ready(1)
        assert not bracket.is_ready(2)

        bracket.rungs[1] = rung_1

        # Same procedure for rung 1.
        rung = bracket.rungs[1]["results"]
        trial_id = next(iter(rung.keys()))
        objective, point = rung[trial_id]
        rung[trial_id] = (None, point)

        assert not bracket.is_ready(
        )  # Should depend on last rung that contains trials
        assert bracket.is_ready(0)
        assert not bracket.is_ready(1)
        assert not bracket.is_ready(2)

        rung[trial_id] = (objective, point)

        assert bracket.is_ready(
        )  # Should depend on last rung that contains trials
        assert bracket.is_ready(0)
        assert bracket.is_ready(1)
        assert not bracket.is_ready(2)

        bracket.rungs[2] = rung_2

        # Same procedure for rung 2, the last rung.
        rung = bracket.rungs[2]["results"]
        trial_id = next(iter(rung.keys()))
        objective, point = rung[trial_id]
        rung[trial_id] = (None, point)

        assert not bracket.is_ready(
        )  # Should depend on last rung that contains trials
        assert bracket.is_ready(0)
        assert bracket.is_ready(1)
        assert not bracket.is_ready(2)

        rung[trial_id] = (objective, point)

        assert bracket.is_ready(
        )  # Should depend on last rung that contains trials
        assert bracket.is_ready(0)
        assert bracket.is_ready(1)
        assert bracket.is_ready(2)
Ejemplo n.º 20
0
    def test_is_filled(
        self,
        hyperband: Hyperband,
        bracket: HyperbandBracket,
        rung_0: RungDict,
        rung_1: RungDict,
        rung_2: RungDict,
    ):
        """Test that Hyperband bracket detects when rung is filled.

        Rungs are installed one at a time. For each newly installed rung, one
        entry is popped from its results and then put back. After each step,
        ``is_filled`` and ``has_rung_filled(i)`` are checked.
        """
        hyperband.brackets = [bracket]
        assert bracket.owner is hyperband
        bracket.rungs[0] = rung_0

        # Remove one rung-0 entry so rung 0 is one result short.
        rung = bracket.rungs[0]["results"]
        trial_id = next(iter(rung.keys()))
        objective, point = rung.pop(trial_id)

        assert not bracket.is_filled
        assert not bracket.has_rung_filled(0)

        # Put it back: rung 0 is filled; rungs 1 and 2 are not installed yet.
        rung[trial_id] = (objective, point)

        assert bracket.is_filled
        assert bracket.has_rung_filled(0)
        assert not bracket.has_rung_filled(1)
        assert not bracket.has_rung_filled(2)

        bracket.rungs[1] = rung_1

        # Same procedure for rung 1.
        rung = bracket.rungs[1]["results"]
        trial_id = next(iter(rung.keys()))
        objective, point = rung.pop(trial_id)

        assert bracket.is_filled  # Should depend on the first rung only
        assert bracket.has_rung_filled(0)
        assert not bracket.has_rung_filled(1)

        rung[trial_id] = (objective, point)

        assert bracket.is_filled  # Should depend on the first rung only
        assert bracket.has_rung_filled(0)
        assert bracket.has_rung_filled(1)
        assert not bracket.has_rung_filled(2)

        bracket.rungs[2] = rung_2

        # Same procedure for rung 2.
        rung = bracket.rungs[2]["results"]
        trial_id = next(iter(rung.keys()))
        objective, point = rung.pop(trial_id)

        assert bracket.is_filled  # Should depend on the first rung only
        assert bracket.has_rung_filled(0)
        assert bracket.has_rung_filled(1)
        assert not bracket.has_rung_filled(2)

        rung[trial_id] = (objective, point)

        assert bracket.is_filled  # Should depend on the first rung only
        assert bracket.has_rung_filled(0)
        assert bracket.has_rung_filled(1)
        assert bracket.has_rung_filled(2)