Esempio n. 1
0
    def test_register_next_bracket(self, space, b_config):
        """Check that a point is registered inside the good bracket when higher fidelity.

        A point observed at fidelity 3 must end up in the first rung of
        bracket 1 only. A second point observed at fidelity 9 must then end up
        in the first rung of bracket 2, leaving bracket 1 unchanged.
        """
        asha = ASHA(
            space,
            max_resources=b_config['R'],
            grace_period=b_config['r'],
            reduction_factor=b_config['eta'],
            num_brackets=3,
        )

        def n_registered(bracket_idx):
            # Total number of points stored across every rung of one bracket.
            return sum(len(rung[1]) for rung in asha.brackets[bracket_idx].rungs)

        def check_first_rung(bracket_idx, point):
            # Points are looked up by the md5 digest of ``str([value])``.
            key = hashlib.md5(str([point[0]]).encode('utf-8')).hexdigest()
            first_rung = asha.brackets[bracket_idx].rungs[0][1]
            assert key in first_rung
            assert (0.0, point) == first_rung[key]

        # Points are written as (value, fidelity).
        low_fidelity_point = (50, 3)
        asha.observe([low_fidelity_point], [{'objective': 0.0}])

        assert n_registered(0) == 0
        assert n_registered(1) == 1
        assert n_registered(2) == 0
        check_first_rung(1, low_fidelity_point)

        high_fidelity_point = (51, 9)
        asha.observe([high_fidelity_point], [{'objective': 0.0}])

        assert n_registered(0) == 0
        assert n_registered(1) == 1
        assert n_registered(2) == 1
        check_first_rung(2, high_fidelity_point)
Esempio n. 2
0
    def test_register_bracket_multi_fidelity(self, space, b_config):
        """Check that a point is registered inside the same bracket for different fidelities.

        The same value is observed first at fidelity 1, then at fidelity 3.
        Both observations must land in bracket 0: the first in rung 0 and the
        second in rung 1. Rung 0 must keep the low-fidelity observation.
        """
        asha = ASHA(space, num_brackets=3)

        value = 50
        # Points are looked up by the md5 digest of their non-fidelity value,
        # so both observations below share the same key.
        point_hash = hashlib.md5(str([value]).encode('utf-8')).hexdigest()

        fidelity = 1
        low_point = (fidelity, value)

        asha.observe([low_point], [{'objective': 0.0}])

        bracket = asha.brackets[0]

        # The registered points live at ``rungs[i][1]``. ``len(rungs[i])``
        # would measure the rung container itself, so count the points mapping.
        assert len(bracket.rungs[0][1]) == 1
        assert point_hash in bracket.rungs[0][1]
        assert (0.0, low_point) == bracket.rungs[0][1][point_hash]

        fidelity = 3
        high_point = [fidelity, value]

        asha.observe([high_point], [{'objective': 0.0}])

        # Rung 0 must still hold exactly the original low-fidelity observation.
        # (A ``!=`` check against ``high_point`` would pass trivially, because
        # a tuple never equals a list.)
        assert len(bracket.rungs[0][1]) == 1
        assert (0.0, low_point) == bracket.rungs[0][1][point_hash]
        # The higher-fidelity observation is registered in rung 1.
        assert point_hash in bracket.rungs[1][1]
        assert (0.0, high_point) == bracket.rungs[1][1][point_hash]
Esempio n. 3
0
    def test_register_bracket_multi_fidelity(self, space, b_config):
        """Check that a point is registered inside the same bracket for different fidelities.

        The same value is observed first at fidelity 1, then at fidelity 3.
        Both observations must land in bracket 0: the first in rung 0 and the
        second in rung 1. Rung 0 must keep the low-fidelity observation.
        """
        asha = ASHA(space,
                    max_resources=b_config['R'],
                    grace_period=b_config['r'],
                    reduction_factor=b_config['eta'],
                    num_brackets=3)

        value = 50
        # Points are looked up by the md5 digest of their non-fidelity value,
        # so both observations below share the same key.
        point_hash = hashlib.md5(str([value]).encode('utf-8')).hexdigest()

        fidelity = 1
        low_point = (value, fidelity)

        asha.observe([low_point], [{'objective': 0.0}])

        bracket = asha.brackets[0]

        # The registered points live at ``rungs[i][1]``. ``len(rungs[i])``
        # would measure the rung container itself, so count the points mapping.
        assert len(bracket.rungs[0][1]) == 1
        assert point_hash in bracket.rungs[0][1]
        assert (0.0, low_point) == bracket.rungs[0][1][point_hash]

        fidelity = 3
        high_point = [value, fidelity]

        asha.observe([high_point], [{'objective': 0.0}])

        # Rung 0 must still hold exactly the original low-fidelity observation.
        # (A ``!=`` check against ``high_point`` would pass trivially, because
        # a tuple never equals a list.)
        assert len(bracket.rungs[0][1]) == 1
        assert (0.0, low_point) == bracket.rungs[0][1][point_hash]
        # The higher-fidelity observation is registered in rung 1.
        assert point_hash in bracket.rungs[1][1]
        assert (0.0, high_point) == bracket.rungs[1][1][point_hash]
Esempio n. 4
0
    def test_register_invalid_fidelity(self, space, b_config):
        """Check that a point cannot be registered if its fidelity is invalid.

        Fidelity 2 is treated as invalid here. After observing such a trial,
        ASHA must report it as neither suggested nor observed.
        """
        asha = ASHA(space, num_brackets=3)
        invalid_fidelity, value = 2, 50
        trial = create_trial_for_hb((invalid_fidelity, value))

        asha.observe([trial])

        for was_seen in (asha.has_suggested, asha.has_observed):
            assert not was_seen(trial)
Esempio n. 5
0
    def test_register_invalid_fidelity(self, space, b_config):
        """Check that a point cannot be registered if its fidelity is invalid.

        Observing a point at fidelity 2 must raise a ``ValueError`` whose
        message reports that no bracket was found for the point.
        """
        asha = ASHA(space, num_brackets=3)
        bad_point = (2, 50)  # (fidelity, value)

        with pytest.raises(ValueError, match='No bracket found for point'):
            asha.observe([bad_point], [{'objective': 0.0}])
Esempio n. 6
0
    def test_register_not_sampled(self, space, b_config, caplog):
        """Check that a point cannot be registered if it was not sampled.

        Observing a trial ASHA never suggested must produce exactly one
        captured log record, whose message template mentions ignoring the trial.
        """
        asha = ASHA(space, num_brackets=3)
        trial = create_trial_for_hb((2, 50))  # (fidelity, value)

        with caplog.at_level(logging.DEBUG, logger="orion.algo.hyperband"):
            asha.observe([trial])

        records = caplog.records
        assert len(records) == 1
        assert "Ignoring trial" in records[0].msg
Esempio n. 7
0
    def test_register_not_sampled(self, space, b_config, caplog):
        """Check that a point cannot be registered if it was not sampled.

        Observing a point ASHA never suggested must produce exactly one
        captured log record, whose message template mentions ignoring the point.
        """
        asha = ASHA(space, num_brackets=3)
        unsampled_point = (2, 50)  # (fidelity, value)

        with caplog.at_level(logging.INFO, logger="orion.algo.asha"):
            asha.observe([unsampled_point], [{"objective": 0.0}])

        records = caplog.records
        assert len(records) == 1
        assert "Ignoring point" in records[0].msg
Esempio n. 8
0
    def test_suggest_in_finite_cardinality(self):
        """Test that suggest returns None once the finite search space is exhausted.

        Values 0-5 of ``yolo1`` are observed at fidelity 1, then values 0-1 at
        fidelity 3. After that, ASHA is expected to have nothing left to suggest.
        """
        space = Space()
        space.register(Integer('yolo1', 'uniform', 0, 6))
        space.register(Fidelity('epoch', 1, 9, 3))

        asha = ASHA(space)
        # Each point is written as (fidelity, value); the objective equals the value.
        for fidelity, n_values in ((1, 6), (3, 2)):
            for value in range(n_values):
                asha.observe([(fidelity, value)], [{'objective': value}])

        assert asha.suggest() is None
Esempio n. 9
0
    def test_register(self, asha: ASHA, bracket: ASHABracket, rung_0: RungDict,
                      rung_1: RungDict):
        """Check that a point is registered inside the bracket.

        The bracket is installed as ASHA's only bracket, with the two rung
        fixtures as its rungs. Observing a fidelity-1 trial must then store
        the trial's objective and params under its fidelity-agnostic id in the
        ``"results"`` of rung 0.
        """
        asha.brackets = [bracket]
        assert bracket.owner is asha
        bracket.rungs = [rung_0, rung_1]
        trial = create_trial_for_hb((1, 0.0), 0.0)
        trial_id = asha.get_id(trial, ignore_fidelity=True)

        asha.observe([trial])

        results = bracket.rungs[0]["results"]
        # ``len(bracket.rungs[0])`` would measure the rung mapping itself (keys
        # such as "results"), not the registered trials, so check ``results``.
        assert len(results)
        assert trial_id in results
        assert results[trial_id][0] == 0.0
        assert results[trial_id][1].params == trial.params
Esempio n. 10
0
    def test_register_invalid_fidelity(self, space, b_config):
        """Check that a point cannot be registered if its fidelity is invalid.

        With the configured budget, observing a point at fidelity 2 must raise
        a ``ValueError`` reporting that no bracket was found for the point.
        """
        asha = ASHA(
            space,
            max_resources=b_config['R'],
            grace_period=b_config['r'],
            reduction_factor=b_config['eta'],
            num_brackets=3,
        )
        bad_point = (50, 2)  # (value, fidelity)

        with pytest.raises(ValueError, match='No bracket found for point'):
            asha.observe([bad_point], [{'objective': 0.0}])
Esempio n. 11
0
    def test_register_corrupted_db(self, caplog, space, b_config):
        """Check that out-of-order fidelities are reported when registering a point.

        Observing a value at fidelity 3 logs nothing about a wrong bracket.
        Then observing the same value at fidelity 1 must log
        'Point registered to wrong bracket'.
        """
        warning = 'Point registered to wrong bracket'
        asha = ASHA(space, num_brackets=3)
        value = 50

        # Points are written as (fidelity, value).
        asha.observe([(3, value)], [{'objective': 0.0}])
        assert warning not in caplog.text

        caplog.clear()
        asha.observe([[1, value]], [{'objective': 0.0}])
        assert warning in caplog.text
Esempio n. 12
0
    def test_register_corrupted_db(self, caplog, space, b_config):
        """Check that out-of-order fidelities are reported when registering a point.

        Observing a value at fidelity 3 logs nothing about a wrong bracket.
        Then observing the same value at fidelity 1 must log
        'Point registered to wrong bracket'.
        """
        warning = 'Point registered to wrong bracket'
        asha = ASHA(
            space,
            max_resources=b_config['R'],
            grace_period=b_config['r'],
            reduction_factor=b_config['eta'],
            num_brackets=3,
        )
        value = 50

        # Points are written as (value, fidelity).
        asha.observe([(value, 3)], [{'objective': 0.0}])
        assert warning not in caplog.text

        caplog.clear()
        asha.observe([[value, 1]], [{'objective': 0.0}])
        assert warning in caplog.text