示例#1
0
    def test_register_bracket_multi_fidelity(self, space, b_config):
        """Check that a point is registered inside the same bracket for diff fidelity.

        The same ``value`` is observed at fidelity 1 and then at fidelity 3. Both
        observations must end up in ``asha.brackets[0]``: the first in rung 0 and
        the second in rung 1, keyed by the fidelity-agnostic trial id.
        """
        asha = ASHA(space, num_brackets=3)

        value = 50
        fidelity = 1
        trial = create_trial_for_hb((fidelity, value), 0.0)
        trial_id = asha.get_id(trial, ignore_fidelity=True)

        force_observe(asha, trial)

        bracket = asha.brackets[0]

        # A rung is a dict with fixed keys (n_trials, resources, results), so
        # ``len(rung)`` is never 0; check the ``results`` mapping instead.
        assert len(bracket.rungs[0]["results"])
        assert trial_id in bracket.rungs[0]["results"]
        assert bracket.rungs[0]["results"][trial_id][0] == 0.0
        assert bracket.rungs[0]["results"][trial_id][1].params == trial.params

        fidelity = 3
        trial = create_trial_for_hb((fidelity, value), 0.0)
        trial_id = asha.get_id(trial, ignore_fidelity=True)

        force_observe(asha, trial)

        assert len(bracket.rungs[1]["results"])
        assert trial_id in bracket.rungs[1]["results"]
        # The rung-0 entry for this id must not be the fidelity-3 trial.
        assert bracket.rungs[0]["results"][trial_id][1].params != trial.params
        assert bracket.rungs[1]["results"][trial_id][0] == 0.0
        assert bracket.rungs[1]["results"][trial_id][1].params == trial.params
示例#2
0
    def test_suggest_promote_identic_objectives(
        self,
        asha: ASHA,
        bracket: ASHABracket,
        big_rung_0: RungDict,
        big_rung_1: RungDict,
    ):
        """Trials tied on the objective must still be promoted.

        Rung 0 is filled with nine trials at resource 1, all with objective 0.
        Asking for two suggestions must return two trials at fidelity 3.
        """
        asha.brackets = [bracket]
        bracket.owner = asha

        resources = 1
        n_trials = 9

        def _entry(lr_value):
            """Build one ``(hash, (objective, trial))`` pair for rung 0."""
            candidate = create_trial_for_hb((resources, lr_value), objective=0)
            key = candidate.compute_trial_hash(
                candidate,
                ignore_fidelity=True,
                ignore_experiment=True,
            )
            assert candidate.objective is not None
            return key, (candidate.objective.value, candidate)

        results = dict(_entry(lr_value) for lr_value in np.linspace(0, 8, 9))

        bracket.rungs[0] = RungDict(
            n_trials=n_trials,
            resources=resources,
            results=results,
        )

        suggested = asha.suggest(2)

        assert len(suggested) == 2
        promoted = [
            candidate for candidate in suggested
            if candidate.params[asha.fidelity_index] == 3
        ]
        assert len(promoted) == 2
示例#3
0
    def test_suggest_in_finite_cardinality(self):
        """Test that suggest returns no trials once the finite search space is exhausted."""
        space = Space()
        space.register(Integer("yolo1", "uniform", 0, 5))
        space.register(Fidelity("epoch", 1, 9, 3))

        asha = ASHA(space)

        # Observe yolo1 values 0..5 at epoch 1, then values 0..1 at epoch 3.
        for epoch, count in ((1, 6), (3, 2)):
            for index in range(count):
                observed = create_trial(
                    (epoch, index),
                    names=("epoch", "yolo1"),
                    types=("fidelity", "integer"),
                    results={"objective": index},
                )
                force_observe(asha, observed)

        assert asha.suggest(1) == []
示例#4
0
    def test_get_id(self, space, b_config):
        """Points differing only in their second coordinate share an id; others do not."""
        asha = ASHA(
            space,
            max_resources=b_config['R'],
            grace_period=b_config['r'],
            reduction_factor=b_config['eta'],
            num_brackets=3,
        )

        left_id = asha.get_id([1, 'whatever'])
        right_id = asha.get_id([1, 'is here'])
        assert left_id == right_id

        left_id = asha.get_id([1, 'whatever'])
        other_id = asha.get_id([2, 'is here'])
        assert left_id != other_id
示例#5
0
    def test_suggest_promote(self, asha: ASHA, bracket: ASHABracket,
                             rung_0: RungDict):
        """Suggest must return the lr=0.0 trial promoted to epoch 3."""
        asha.brackets = [bracket]
        assert bracket.owner is asha
        bracket.rungs[0] = rung_0

        suggested = asha.suggest(1)
        promoted = suggested[0]

        assert promoted.params == {"lr": 0.0, "epoch": 3}
示例#6
0
    def test_register_invalid_fidelity(self, space, b_config):
        """Observing a point whose fidelity matches no bracket raises ValueError."""
        asha = ASHA(space, num_brackets=3)

        # (fidelity, value); this test treats fidelity 2 as invalid.
        bad_point = (2, 50)

        with pytest.raises(ValueError) as raised:
            asha.observe([bad_point], [{'objective': 0.0}])

        assert 'No bracket found for point' in str(raised.value)
示例#7
0
    def test_register_not_sampled(self, space, b_config, caplog):
        """Observing a trial that ASHA never sampled is ignored with one log record."""
        asha = ASHA(space, num_brackets=3)

        unsampled = create_trial_for_hb((2, 50))

        with caplog.at_level(logging.DEBUG, logger="orion.algo.hyperband"):
            asha.observe([unsampled])

        records = caplog.records
        assert len(records) == 1
        assert "Ignoring trial" in records[0].msg
示例#8
0
    def test_register_not_sampled(self, space, b_config, caplog):
        """Observing a point that ASHA never sampled is ignored with one log record."""
        asha = ASHA(space, num_brackets=3)

        unsampled_point = (2, 50)  # (fidelity, value)

        with caplog.at_level(logging.INFO, logger="orion.algo.asha"):
            asha.observe([unsampled_point], [{"objective": 0.0}])

        records = caplog.records
        assert len(records) == 1
        assert "Ignoring point" in records[0].msg
示例#9
0
    def test_suggest_in_finite_cardinality(self):
        """Test that suggest returns None once the finite search space is exhausted."""
        space = Space()
        space.register(Integer("yolo1", "uniform", 0, 5))
        space.register(Fidelity("epoch", 1, 9, 3))

        asha = ASHA(space)

        # Six values at epoch 1, then the first two values again at epoch 3.
        for epoch, count in ((1, 6), (3, 2)):
            for index in range(count):
                force_observe(asha, (epoch, index), {"objective": index})

        assert asha.suggest() is None
示例#10
0
    def test_register(self, asha: ASHA, bracket: ASHABracket, rung_0: RungDict,
                      rung_1: RungDict):
        """Check that a point is registered inside the bracket."""
        asha.brackets = [bracket]
        assert bracket.owner is asha
        bracket.rungs = [rung_0, rung_1]
        trial = create_trial_for_hb((1, 0.0), 0.0)
        trial_id = asha.get_id(trial, ignore_fidelity=True)

        asha.observe([trial])

        # ``len(rung)`` counts the rung dict's fixed keys and is never 0; the
        # meaningful check is that its ``results`` mapping is non-empty.
        assert len(bracket.rungs[0]["results"])
        assert trial_id in bracket.rungs[0]["results"]
        assert bracket.rungs[0]["results"][trial_id][0] == 0.0
        assert bracket.rungs[0]["results"][trial_id][1].params == trial.params
示例#11
0
    def test_register_next_bracket(self, space, b_config):
        """Points observed at fidelity 3 and 9 land in brackets 1 and 2 respectively."""
        asha = ASHA(space, num_brackets=3)

        def observe_at(fidelity, value):
            """Observe ``(fidelity, value)`` with objective 0; return it and its hash."""
            point = (fidelity, value)
            point_hash = hashlib.md5(str([value]).encode("utf-8")).hexdigest()
            force_observe(asha, point, {"objective": 0.0})
            return point, point_hash

        def registered_per_bracket():
            """Number of registered points in each of the first three brackets."""
            return [
                sum(len(rung[1]) for rung in asha.brackets[index].rungs)
                for index in range(3)
            ]

        point, point_hash = observe_at(3, 50)
        assert registered_per_bracket() == [0, 1, 0]
        assert point_hash in asha.brackets[1].rungs[0][1]
        assert (0.0, point) == asha.brackets[1].rungs[0][1][point_hash]

        point, point_hash = observe_at(9, 51)
        assert registered_per_bracket() == [0, 1, 1]
        assert point_hash in asha.brackets[2].rungs[0][1]
        assert (0.0, point) == asha.brackets[2].rungs[0][1][point_hash]
示例#12
0
    def test_register_bracket_multi_fidelity(self, space, b_config):
        """Check that a point is registered inside the same bracket for diff fidelity.

        The same ``value`` is observed at fidelity 1 and then at fidelity 3; both
        must be recorded in ``asha.brackets[0]`` (rung 0, then rung 1). Registered
        points live in ``rung[1]``, so emptiness checks look there rather than at
        ``len(rung)``, which only measures the rung container itself.
        """
        asha = ASHA(space, num_brackets=3)

        value = 50
        fidelity = 1
        point = (fidelity, value)
        # The hash only covers ``value``, so it identifies the point across fidelities.
        point_hash = hashlib.md5(str([value]).encode("utf-8")).hexdigest()

        force_observe(asha, point, {"objective": 0.0})

        bracket = asha.brackets[0]

        assert len(bracket.rungs[0][1])
        assert point_hash in bracket.rungs[0][1]
        assert (0.0, point) == bracket.rungs[0][1][point_hash]

        fidelity = 3
        point = [fidelity, value]

        force_observe(asha, point, {"objective": 0.0})

        # The higher-fidelity observation must populate rung 1.
        assert len(bracket.rungs[1][1])
        assert point_hash in bracket.rungs[1][1]
        assert (0.0, point) != bracket.rungs[0][1][point_hash]
        assert (0.0, point) == bracket.rungs[1][1][point_hash]
示例#13
0
    def test_register_invalid_fidelity(self, space, b_config):
        """Observing a point whose fidelity matches no bracket raises ValueError."""
        asha = ASHA(
            space,
            max_resources=b_config['R'],
            grace_period=b_config['r'],
            reduction_factor=b_config['eta'],
            num_brackets=3,
        )

        # Points are laid out as (value, fidelity) here; fidelity 2 is invalid.
        invalid_point = (50, 2)

        with pytest.raises(ValueError) as raised:
            asha.observe([invalid_point], [{'objective': 0.0}])

        assert 'No bracket found for point' in str(raised.value)
示例#14
0
    def test_suggest_promote_many(
        self,
        asha: ASHA,
        bracket: ASHABracket,
        big_rung_0: RungDict,
        big_rung_1: RungDict,
    ):
        """Suggesting three trials yields two at fidelity 9 and one at fidelity 3."""
        asha.brackets = [bracket]
        assert bracket.owner is asha
        bracket.rungs[0] = big_rung_0
        bracket.rungs[1] = big_rung_1

        suggested = asha.suggest(3)

        assert len(suggested) == 3
        fidelities = [
            candidate.params[asha.fidelity_index] for candidate in suggested
        ]
        assert fidelities.count(9) == 2
        assert fidelities.count(3) == 1
示例#15
0
    def test_suggest_new(
        self,
        monkeypatch,
        asha: ASHA,
        bracket: ASHABracket,
        rung_0: RungDict,
        rung_1: RungDict,
        rung_2: RungDict,
    ):
        """A newly sampled trial is returned with its fidelity set to epoch 1."""
        asha.brackets = [bracket]
        assert bracket.owner is asha

        def fake_sample(num=1, seed=None):
            # The sampled trial carries a placeholder fidelity value.
            return [create_trial_for_hb(("fidelity", 0.5))]

        monkeypatch.setattr(asha.space, "sample", fake_sample)

        suggested = asha.suggest(1)

        assert suggested[0].params == {"lr": 0.5, "epoch": 1}
示例#16
0
    def test_register_invalid_fidelity(self, space, b_config):
        """Forcing an observation at an invalid fidelity raises ValueError."""
        asha = ASHA(space, num_brackets=3)

        invalid_point = (2, 50)  # (fidelity, value)

        with pytest.raises(ValueError) as raised:
            force_observe(asha, invalid_point, {"objective": 0.0})

        assert "No bracket found for point" in str(raised.value)
示例#17
0
    def test_register_bracket_multi_fidelity(self, space, b_config):
        """Check that a point is registered inside the same bracket for diff fidelity.

        The same ``value`` is observed at fidelity 1 and then at fidelity 3; both
        must be recorded in ``asha.brackets[0]`` (rung 0, then rung 1). Registered
        points live in ``rung[1]``, so emptiness checks look there rather than at
        ``len(rung)``, which only measures the rung container itself.
        """
        asha = ASHA(space,
                    max_resources=b_config['R'],
                    grace_period=b_config['r'],
                    reduction_factor=b_config['eta'],
                    num_brackets=3)

        value = 50
        fidelity = 1
        point = (value, fidelity)
        # The hash only covers ``value``, so it identifies the point across fidelities.
        point_hash = hashlib.md5(str([value]).encode('utf-8')).hexdigest()

        asha.observe([point], [{'objective': 0.0}])

        bracket = asha.brackets[0]

        assert len(bracket.rungs[0][1])
        assert point_hash in bracket.rungs[0][1]
        assert (0.0, point) == bracket.rungs[0][1][point_hash]

        fidelity = 3
        point = [value, fidelity]

        asha.observe([point], [{'objective': 0.0}])

        # The higher-fidelity observation must populate rung 1.
        assert len(bracket.rungs[1][1])
        assert point_hash in bracket.rungs[1][1]
        assert (0.0, point) != bracket.rungs[0][1][point_hash]
        assert (0.0, point) == bracket.rungs[1][1][point_hash]
示例#18
0
    def test_register_next_bracket(self, space, b_config):
        """Points observed at fidelity 3 and 9 land in brackets 1 and 2 respectively."""
        asha = ASHA(
            space,
            max_resources=b_config['R'],
            grace_period=b_config['r'],
            reduction_factor=b_config['eta'],
            num_brackets=3,
        )

        def observe_at(value, fidelity):
            """Observe ``(value, fidelity)`` with objective 0; return it and its hash."""
            point = (value, fidelity)
            point_hash = hashlib.md5(str([value]).encode('utf-8')).hexdigest()
            asha.observe([point], [{'objective': 0.0}])
            return point, point_hash

        def registered_per_bracket():
            """Number of registered points in each of the first three brackets."""
            return [
                sum(len(rung[1]) for rung in asha.brackets[index].rungs)
                for index in range(3)
            ]

        point, point_hash = observe_at(50, 3)
        assert registered_per_bracket() == [0, 1, 0]
        assert point_hash in asha.brackets[1].rungs[0][1]
        assert (0.0, point) == asha.brackets[1].rungs[0][1][point_hash]

        point, point_hash = observe_at(51, 9)
        assert registered_per_bracket() == [0, 1, 1]
        assert point_hash in asha.brackets[2].rungs[0][1]
        assert (0.0, point) == asha.brackets[2].rungs[0][1][point_hash]
示例#19
0
    def test_update_rungs_return_candidate(self, asha: ASHA,
                                           bracket: ASHABracket,
                                           rung_1: RungDict):
        """Promoting out of rung 1 returns an epoch-9 candidate and keeps rung 1's entry."""
        assert bracket.owner is asha
        bracket.rungs[1] = rung_1
        expected = create_trial_for_hb((3, 0.0), 0.0)

        promoted = bracket.promote(1)[0]

        expected_id = asha.get_id(expected, ignore_fidelity=True)
        rung_results = bracket.rungs[1]["results"]
        assert expected_id in rung_results
        assert rung_results[expected_id][1].params == expected.params
        assert promoted.params["epoch"] == 9
示例#20
0
    def test_register_next_bracket(self, space, b_config):
        """Trials observed at fidelity 3 and 9 land in brackets 1 and 2 respectively."""
        asha = ASHA(space, num_brackets=3)

        def observe_at(fidelity, value):
            """Observe a trial with objective 0; return it and its fidelity-free id."""
            trial = create_trial_for_hb((fidelity, value), 0.0)
            trial_id = asha.get_id(trial, ignore_fidelity=True)
            force_observe(asha, trial)
            return trial, trial_id

        def registered_per_bracket():
            """Number of registered results in each of the first three brackets."""
            return [
                sum(len(rung["results"]) for rung in asha.brackets[index].rungs)
                for index in range(3)
            ]

        trial, trial_id = observe_at(3, 50)
        assert registered_per_bracket() == [0, 1, 0]
        first_rung = asha.brackets[1].rungs[0]["results"]
        assert trial_id in first_rung
        compare_registered_trial(first_rung[trial_id], trial)

        trial, trial_id = observe_at(9, 51)
        assert registered_per_bracket() == [0, 1, 1]
        first_rung = asha.brackets[2].rungs[0]["results"]
        assert trial_id in first_rung
        compare_registered_trial(first_rung[trial_id], trial)
示例#21
0
    def test_suggest_inf_duplicates(
        self,
        monkeypatch,
        asha: ASHA,
        bracket: ASHABracket,
        rung_0: RungDict,
        rung_1: RungDict,
        rung_2: RungDict,
    ):
        """If sampling only ever yields an already-registered trial, suggest returns []."""
        asha.brackets = [bracket]
        assert bracket.owner is asha

        fidelity = 1
        duplicate = create_trial_for_hb((fidelity, 0.0))
        duplicate_id = asha.get_id(duplicate, ignore_fidelity=True)
        asha.trial_to_brackets[duplicate_id] = 0

        def always_duplicate(num=1, seed=None):
            return [duplicate]

        monkeypatch.setattr(asha.space, "sample", always_duplicate)

        assert asha.suggest(1) == []
示例#22
0
    def test_promotion_with_rung_1_hit(self, asha: ASHA, bracket: ASHABracket,
                                       rung_0: RungDict):
        """If the lr=0.0 trial is already in rung 1, get_candidates returns the lr=1.0 one."""
        already_promoted = create_trial_for_hb((1, 0.0), None)
        assert bracket.owner is asha
        bracket.rungs[0] = rung_0
        assert already_promoted.objective is not None

        promoted_id = asha.get_id(already_promoted, ignore_fidelity=True)
        bracket.rungs[1]["results"][promoted_id] = (
            already_promoted.objective.value,
            already_promoted,
        )

        candidate = bracket.get_candidates(0)[0]

        expected = create_trial_for_hb((1, 1.0), 0.0)
        assert candidate.params == expected.params
示例#23
0
    def test_suggest_in_finite_cardinality(self):
        """Test that suggest returns None once the finite search space is exhausted."""
        space = Space()
        space.register(Integer('yolo1', 'uniform', 0, 6))
        space.register(Fidelity('epoch', 1, 9, 3))

        asha = ASHA(space)

        # Six values at epoch 1, then the first two values again at epoch 3.
        for epoch, count in ((1, 6), (3, 2)):
            for index in range(count):
                asha.observe([(epoch, index)], [{'objective': index}])

        assert asha.suggest() is None
示例#24
0
    def test_register_corrupted_db(self, caplog, space, b_config):
        """Observing a value at fidelity 3 then at fidelity 1 logs a wrong-bracket message."""
        asha = ASHA(space, num_brackets=3)
        value = 50
        wrong_bracket_msg = "Trial registered to wrong bracket"

        force_observe(asha, create_trial_for_hb((3, value)))
        assert wrong_bracket_msg not in caplog.text

        lower_fidelity_trial = create_trial_for_hb((1, value), objective=0.0)

        caplog.clear()
        force_observe(asha, lower_fidelity_trial)
        assert wrong_bracket_msg in caplog.text
示例#25
0
    def test_register_corrupted_db(self, caplog, space, b_config):
        """Observing a value at fidelity 3 then at fidelity 1 logs a wrong-bracket message."""
        asha = ASHA(space, num_brackets=3)
        value = 50
        wrong_bracket_msg = "Point registered to wrong bracket"

        force_observe(asha, (3, value), {"objective": 0.0})
        assert wrong_bracket_msg not in caplog.text

        lower_fidelity_point = [1, value]

        caplog.clear()
        force_observe(asha, lower_fidelity_point, {"objective": 0.0})
        assert wrong_bracket_msg in caplog.text
示例#26
0
    def test_register_invalid_fidelity(self, space, b_config):
        """A trial at an invalid fidelity is recorded as neither suggested nor observed."""
        asha = ASHA(space, num_brackets=3)

        invalid_trial = create_trial_for_hb((2, 50))
        asha.observe([invalid_trial])

        for was_recorded in (asha.has_suggested, asha.has_observed):
            assert not was_recorded(invalid_trial)
示例#27
0
    def test_get_id_multidim(self, b_config):
        """Ids ignore the fidelity even when ``lr`` has shape 2; a different ``lr`` changes the id."""
        space = Space()
        space.register(Fidelity("epoch", 1, 9, 3))
        space.register(Real("lr", "uniform", 0, 1, shape=2))

        asha = ASHA(space, num_brackets=3)

        first_id = asha.get_id(["whatever", [1, 1]])
        second_id = asha.get_id(["is here", [1, 1]])
        assert first_id == second_id

        first_id = asha.get_id(["whatever", [1, 1]])
        other_id = asha.get_id(["is here", [2, 2]])
        assert first_id != other_id
示例#28
0
    def test_get_id_multidim(self, b_config):
        """Ids ignore the fidelity even when ``lr`` has shape 2; a different ``lr`` changes the id."""
        space = Space()
        space.register(Fidelity('epoch', 1, 9, 3))
        space.register(Real('lr', 'uniform', 0, 1, shape=2))

        asha = ASHA(space, num_brackets=3)

        first_id = asha.get_id(['whatever', [1, 1]])
        second_id = asha.get_id(['is here', [1, 1]])
        assert first_id == second_id

        first_id = asha.get_id(['whatever', [1, 1]])
        other_id = asha.get_id(['is here', [2, 2]])
        assert first_id != other_id
示例#29
0
    def test_get_id_multidim(self, b_config):
        """Ids ignore the fidelity even when ``lr`` has shape 2; a different ``lr`` changes the id."""
        space = Space()
        space.register(Fidelity('epoch'))
        space.register(Real('lr', 'uniform', 0, 1, shape=2))

        asha = ASHA(
            space,
            max_resources=b_config['R'],
            grace_period=b_config['r'],
            reduction_factor=b_config['eta'],
            num_brackets=3,
        )

        first_id = asha.get_id(['whatever', [1, 1]])
        second_id = asha.get_id(['is here', [1, 1]])
        assert first_id == second_id

        first_id = asha.get_id(['whatever', [1, 1]])
        other_id = asha.get_id(['is here', [2, 2]])
        assert first_id != other_id
示例#30
0
    def test_register_corrupted_db(self, caplog, space, b_config):
        """Observing a value at fidelity 3 then at fidelity 1 logs a wrong-bracket message."""
        asha = ASHA(space, num_brackets=3)
        value = 50
        wrong_bracket_msg = 'Point registered to wrong bracket'

        asha.observe([(3, value)], [{'objective': 0.0}])
        assert wrong_bracket_msg not in caplog.text

        lower_fidelity_point = [1, value]

        caplog.clear()
        asha.observe([lower_fidelity_point], [{'objective': 0.0}])
        assert wrong_bracket_msg in caplog.text