Exemple #1
0
    def test_register_next_bracket(self, space):
        """Observe points of increasing fidelity and verify which bracket stores them.

        A point at fidelity 3 is expected only in the second bracket, and a
        later point at fidelity 9 only in the third one.
        """
        algo = Hyperband(space)

        def stored_in(bracket_index):
            """Count the results held across every rung of one bracket."""
            rungs = algo.brackets[bracket_index].rungs
            return sum(len(rung['results']) for rung in rungs)

        def hash_of(param_value):
            """Hash the parameter value alone; the fidelity is left out."""
            return hashlib.md5(str([param_value]).encode('utf-8')).hexdigest()

        # First observation: fidelity 3.
        low_point = (3, 50)
        low_hash = hash_of(50)
        algo.observe([low_point], [{'objective': 0.0}])

        assert stored_in(0) == 0
        assert stored_in(1) == 1
        assert stored_in(2) == 0
        second_bracket_results = algo.brackets[1].rungs[0]['results']
        assert low_hash in second_bracket_results
        assert (0.0, low_point) == second_bracket_results[low_hash]

        # Second observation: fidelity 9, different parameter value.
        high_point = (9, 51)
        high_hash = hash_of(51)
        algo.observe([high_point], [{'objective': 0.0}])

        assert stored_in(0) == 0
        assert stored_in(1) == 1
        assert stored_in(2) == 1
        third_bracket_results = algo.brackets[2].rungs[0]['results']
        assert high_hash in third_bracket_results
        assert (0.0, high_point) == third_bracket_results[high_hash]
Exemple #2
0
    def test_register_bracket_multi_fidelity(self, space):
        """Check that a point is registered inside the same bracket for diff fidelity.

        The same parameter value is observed first at fidelity 1 and then at
        fidelity 3. Results are looked up by a hash of the value alone
        (fidelity excluded), so both observations share one key and are
        expected in rung 0 and rung 1 of the first bracket respectively.
        """
        hyperband = Hyperband(space)

        value = 50
        # The fidelity is not part of the hash, so one hash serves both points.
        point_hash = hashlib.md5(str([value]).encode('utf-8')).hexdigest()

        fidelity = 1
        point = (fidelity, value)

        hyperband.observe([point], [{'objective': 0.0}])

        bracket = hyperband.brackets[0]

        # Inspect the stored results rather than the rung dict itself: the
        # rung always has keys, so ``len(rung)`` is truthy even when empty.
        assert bracket.rungs[0]['results']
        assert point_hash in bracket.rungs[0]['results']
        assert (0.0, point) == bracket.rungs[0]['results'][point_hash]

        fidelity = 3
        point = [fidelity, value]

        hyperband.observe([point], [{'objective': 0.0}])

        # The lower-fidelity result stays in rung 0; the new one goes to rung 1.
        assert bracket.rungs[0]['results']
        assert point_hash in bracket.rungs[1]['results']
        assert (0.0, point) != bracket.rungs[0]['results'][point_hash]
        assert (0.0, point) == bracket.rungs[1]['results'][point_hash]
Exemple #3
0
    def test_suggest_in_finite_cardinality(self):
        """Expect ``suggest()`` to return None after six points were observed.

        The space holds one small integer dimension plus a fidelity dimension,
        and six points are observed at fidelity 1 before asking for more.
        """
        search_space = Space()
        search_space.register(Integer('yolo1', 'uniform', 0, 6))
        search_space.register(Fidelity('epoch', 1, 9, 3))

        algo = Hyperband(search_space, repetitions=1)
        for param in range(6):
            # The parameter value doubles as the reported objective.
            observed_point = (1, param)
            algo.observe([observed_point], [{'objective': param}])

        suggestion = algo.suggest()
        assert suggestion is None
Exemple #4
0
    def test_register_invalid_fidelity(self, space):
        """Observing a point at fidelity 2 must raise a ``ValueError``.

        The error message is expected to report that no bracket was found.
        """
        algo = Hyperband(space)
        rejected_point = (2, 50)

        with pytest.raises(ValueError) as raised:
            algo.observe([rejected_point], [{'objective': 0.0}])

        message = str(raised.value)
        assert 'No bracket found for point' in message
Exemple #5
0
    def test_register_invalid_fidelity(self, space: Space):
        """A trial at fidelity 2 is neither suggested nor observed after ``observe``."""
        algo = Hyperband(space)
        rejected_trial = create_trial_for_hb((2, 50))

        algo.observe([rejected_trial])

        # The algorithm must not have recorded the trial in any way.
        was_suggested = algo.has_suggested(rejected_trial)
        assert not was_suggested
        was_observed = algo.has_observed(rejected_trial)
        assert not was_observed
Exemple #6
0
    def test_register_not_sampled(self, space, caplog):
        """Observing a point at fidelity 2 logs exactly one "Ignoring point" record."""
        algo = Hyperband(space)
        unsampled_point = (2, 50)

        # Capture INFO-level output from the hyperband logger only.
        with caplog.at_level(logging.INFO, logger="orion.algo.hyperband"):
            algo.observe([unsampled_point], [{"objective": 0.0}])

        captured = caplog.records
        assert len(captured) == 1
        assert "Ignoring point" in captured[0].msg
Exemple #7
0
    def test_register_not_sampled(self, space, caplog):
        """Observing a trial at fidelity 2 logs exactly one "Ignoring trial" record."""
        algo = Hyperband(space)
        unsampled_trial = create_trial_for_hb((2, 50))

        # Capture DEBUG-level output from the hyperband logger only.
        with caplog.at_level(logging.DEBUG, logger="orion.algo.hyperband"):
            algo.observe([unsampled_trial])

        captured = caplog.records
        assert len(captured) == 1
        assert "Ignoring trial" in captured[0].msg
Exemple #8
0
    def test_register_corrupted_db(self, caplog, space):
        """Observing fidelity 1 after fidelity 3 for the same value logs a warning.

        The first observation (fidelity 3) must not produce the message; the
        second one (fidelity 1, same parameter value) must.
        """
        algo = Hyperband(space)
        value = 50

        high_fidelity_point = (3, value)
        algo.observe([high_fidelity_point], [{'objective': 0.0}])
        assert 'Point registered to wrong bracket' not in caplog.text

        # Drop anything captured so far so only the second observation counts.
        caplog.clear()

        low_fidelity_point = [1, value]
        algo.observe([low_fidelity_point], [{'objective': 0.0}])
        assert 'Point registered to wrong bracket' in caplog.text
Exemple #9
0
    def test_register(
        self,
        hyperband: Hyperband,
        bracket: HyperbandBracket,
        rung_0: RungDict,
        rung_1: RungDict,
    ):
        """Check that a point is registered inside the bracket.

        The algorithm is given a single bracket made of the two rung fixtures,
        a trial at fidelity 1 with objective 0.0 is observed, and the first
        rung is then expected to hold that trial under its fidelity-agnostic id.
        """
        hyperband.brackets = [bracket]
        assert bracket.owner is hyperband
        bracket.rungs = [rung_0, rung_1]
        trial = create_trial_for_hb((1, 0.0), 0.0)
        trial_id = hyperband.get_id(trial, ignore_fidelity=True)

        hyperband.observe([trial])

        # Inspect the stored results rather than the rung dict itself: the
        # rung always has keys, so ``len(rung)`` is truthy even when empty.
        assert bracket.rungs[0]["results"]
        assert trial_id in bracket.rungs[0]["results"]
        assert bracket.rungs[0]["results"][trial_id][0] == 0.0
        assert bracket.rungs[0]["results"][trial_id][1].params == trial.params
Exemple #10
0
def force_observe(hyperband: Hyperband, trial: Trial) -> None:
    """Register ``trial``, make sure it is mapped to a bracket, then observe it.

    The bracket is looked up in ``hyperband.trial_to_brackets`` under the
    trial's fidelity-agnostic id. When no mapping exists, the first bracket
    whose first rung has ``"resources"`` equal to the trial's fidelity is
    used; an ``IndexError`` is raised if no bracket matches.
    """
    hyperband.register(trial)

    trial_key = hyperband.get_id(trial, ignore_fidelity=True)
    assigned_bracket = hyperband.trial_to_brackets.get(trial_key)

    if assigned_bracket is None:
        # The fidelity sits at a fixed position in the flattened parameters.
        trial_fidelity = flatten(trial.params)[hyperband.fidelity_index]
        assert hyperband.brackets is not None
        matching_indices = [
            index
            for index, candidate in enumerate(hyperband.brackets)
            if candidate.rungs[0]["resources"] == trial_fidelity
        ]
        assigned_bracket = matching_indices[0]

    hyperband.trial_to_brackets[trial_key] = assigned_bracket

    hyperband.observe([trial])
Exemple #11
0
    def test_full_process(self, monkeypatch, hyperband: Hyperband):
        """Test Hyperband full process.

        Walks one full repetition step by step: fill the first rung of each of
        the three brackets, observe results, check which trials get promoted to
        the next rung, and verify bracket readiness/completion flags after each
        step. Finally the number of repetitions is raised to 2 and the brackets
        are refreshed to check that a second set of brackets is created and
        filled.
        """
        sample_trials = [
            create_trial_for_hb(("fidelity", i)) for i in range(100)
        ]

        hyperband._refresh_brackets()
        # Deep copy so the trials handed to the algorithm can be mutated
        # without touching `sample_trials`, which is reused further down.
        mock_samples(hyperband, copy.deepcopy(sample_trials))

        # Fill all brackets' first rung
        first_rung = hyperband.suggest(100)
        assert first_rung is not None
        # The suggestions are split by position: the first 3 are compared to
        # fidelity-9 trials, the next 3 to fidelity 3, the remaining 9 to fidelity 1.
        first_bracket_first_rung = first_rung[6:]
        second_bracket_first_rung = first_rung[3:6]
        third_bracket_first_rung = first_rung[:3]

        compare_trials(
            first_bracket_first_rung,
            [create_trial_for_hb((1, i)) for i in range(6, 15)],
        )
        compare_trials(
            second_bracket_first_rung,
            [create_trial_for_hb((3, i)) for i in range(3, 6)],
        )
        compare_trials(third_bracket_first_rung,
                       [create_trial_for_hb((9, i)) for i in range(3)])
        assert hyperband.brackets is not None
        assert hyperband.brackets[0].has_rung_filled(0)
        assert not hyperband.brackets[0].is_ready()
        # With nothing observed yet, repeated calls must yield no new trials.
        assert hyperband.suggest(100) == []
        assert hyperband.suggest(100) == []

        # Observe first bracket first rung
        # Objectives decrease with position (16, 15, ..., 8), so the last
        # trials of the rung get the lowest objectives.
        for i, trial in enumerate(first_bracket_first_rung):
            trial.status = "completed"
            trial._results.append(
                Trial.Result(name="objective", type="objective", value=16 - i))
        hyperband.observe(first_bracket_first_rung)

        assert hyperband.brackets[0].is_ready()
        assert not hyperband.brackets[1].is_ready()
        assert not hyperband.brackets[2].is_ready()

        # Promote first bracket first rung
        first_bracket_second_rung = hyperband.suggest(100)
        # Expected promotions are the parameter values 14, 13, 12 at fidelity 3,
        # i.e. the three trials that were given the lowest objectives above.
        compare_trials(
            first_bracket_second_rung,
            [create_trial_for_hb((3, 3 + 3 + 9 - 1 - i)) for i in range(3)],
        )

        assert hyperband.brackets[0].has_rung_filled(1)
        assert not hyperband.brackets[0].is_ready()
        assert not hyperband.brackets[1].is_ready()
        assert not hyperband.brackets[2].is_ready()

        assert first_bracket_second_rung is not None
        # Observe first bracket second rung
        # Objectives 8, 7, 6: the last trial of the rung gets the lowest one.
        for i, trial in enumerate(first_bracket_second_rung):
            trial.status = "completed"
            trial._results.append(
                Trial.Result(name="objective", type="objective", value=8 - i))
        hyperband.observe(first_bracket_second_rung)

        assert hyperband.brackets[0].is_ready()
        assert not hyperband.brackets[1].is_ready()
        assert not hyperband.brackets[2].is_ready()

        # Promote first bracket second rung
        first_bracket_third_rung = hyperband.suggest(100)
        compare_trials(first_bracket_third_rung,
                       [create_trial_for_hb((9, 12))])

        assert hyperband.brackets[0].has_rung_filled(2)
        assert not hyperband.brackets[0].is_ready()
        assert not hyperband.brackets[1].is_ready()
        assert not hyperband.brackets[2].is_ready()

        # Observe second bracket first rung
        # Objectives 8, 7, 6: the last trial of the rung gets the lowest one.
        for i, trial in enumerate(second_bracket_first_rung):
            trial.status = "completed"
            trial._results.append(
                Trial.Result(name="objective", type="objective", value=8 - i))
        hyperband.observe(second_bracket_first_rung)

        assert not hyperband.brackets[0].is_ready()
        assert hyperband.brackets[1].is_ready()
        assert not hyperband.brackets[2].is_ready()

        # Promote second bracket first rung
        second_bracket_second_rung = hyperband.suggest(100)
        compare_trials(second_bracket_second_rung,
                       [create_trial_for_hb((9, 5))])

        assert not hyperband.brackets[0].is_ready()
        assert hyperband.brackets[1].has_rung_filled(1)
        assert not hyperband.brackets[1].is_ready()
        assert not hyperband.brackets[2].is_ready()

        # Observe third bracket first rung
        for i, trial in enumerate(third_bracket_first_rung):
            trial.status = "completed"
            trial._results.append(
                Trial.Result(name="objective", type="objective", value=3 - i))
        hyperband.observe(third_bracket_first_rung)

        # From here on readiness is queried for an explicit rung index.
        assert not hyperband.brackets[0].is_ready(2)
        assert not hyperband.brackets[1].is_ready(1)
        assert hyperband.brackets[2].is_ready(0)
        assert hyperband.brackets[2].is_done

        assert second_bracket_second_rung is not None
        # Observe second bracket second rung
        for i, trial in enumerate(second_bracket_second_rung):
            trial.status = "completed"
            trial._results.append(
                Trial.Result(name="objective", type="objective", value=5 - i))
        hyperband.observe(second_bracket_second_rung)

        assert not hyperband.brackets[0].is_ready(2)
        assert hyperband.brackets[1].is_ready(1)
        assert hyperband.brackets[1].is_done
        assert first_bracket_third_rung is not None
        # NOTE(review): unlike every other rung in this test, these trials are
        # observed without setting `status` or appending an objective result
        # first — confirm this is intended.
        hyperband.observe(first_bracket_third_rung)

        assert hyperband.is_done
        assert hyperband.brackets[0].is_done
        assert hyperband.suggest(100) == []

        # Refresh repeat and execution times property
        monkeypatch.setattr(hyperband, "repetitions", 2)
        # monkeypatch.setattr(hyperband.brackets[0], "repetition_id", 0)
        # hyperband.observe([(9, 12)], [{"objective": 3 - i}])
        # Refreshing after raising `repetitions` is expected to add a second
        # set of three brackets.
        assert len(hyperband.brackets) == 3
        hyperband._refresh_brackets()
        assert len(hyperband.brackets) == 6
        mock_samples(hyperband,
                     copy.deepcopy(sample_trials[:3] + sample_trials))
        trials = hyperband.suggest(100)
        assert not hyperband.is_done
        assert not hyperband.brackets[3].is_ready(2)
        assert not hyperband.brackets[3].is_done
        assert trials is not None
        # Same positional layout as in the first repetition: fidelity 9 first,
        # then fidelity 3, then fidelity 1.
        compare_trials(trials[:3],
                       map(create_trial_for_hb, [(9, 3), (9, 4), (9, 6)]))
        compare_trials(trials[3:6],
                       map(create_trial_for_hb, [(3, 7), (3, 8), (3, 9)]))
        compare_trials(trials[6:],
                       [create_trial_for_hb((1, i)) for i in range(15, 24)])