Example #1
0
    def test_register_bracket_multi_fidelity(self, space: Space):
        """Check that a point is registered inside the same bracket for diff fidelity.

        The trial observed at fidelity 1 must land in rung 0 of the first
        bracket. The same parameter value observed again at fidelity 3 must land
        in rung 1 of that same bracket, under the fidelity-agnostic trial id.
        """
        hyperband = Hyperband(space)

        value = 50
        fidelity = 1
        trial = create_trial_for_hb((fidelity, value), 0.0)
        trial_id = hyperband.get_id(trial, ignore_fidelity=True)

        force_observe(hyperband, trial)
        assert hyperband.brackets is not None
        bracket = hyperband.brackets[0]

        # Check the rung's ``results`` mapping. A rung that has a "results" key
        # is itself never empty, so ``len(rung)`` alone verified nothing.
        assert len(bracket.rungs[0]["results"])
        assert trial_id in bracket.rungs[0]["results"]
        assert bracket.rungs[0]["results"][trial_id][0] == 0.0
        assert bracket.rungs[0]["results"][trial_id][1].params == trial.params

        fidelity = 3
        trial = create_trial_for_hb((fidelity, value), 0.0)
        trial_id = hyperband.get_id(trial, ignore_fidelity=True)

        force_observe(hyperband, trial)

        assert len(bracket.rungs[1]["results"])
        assert trial_id in bracket.rungs[1]["results"]
        # The rung-0 entry for this id must not have been replaced by the
        # fidelity-3 trial.
        assert bracket.rungs[0]["results"][trial_id][1].params != trial.params
        assert bracket.rungs[1]["results"][trial_id][0] == 0.0
        assert bracket.rungs[1]["results"][trial_id][1].params == trial.params
Example #2
0
    def test_get_id_multidim(self):
        """Check ``get_id`` on points whose non-fidelity dimension has shape > 1.

        By default only the ``lr`` vector decides whether two ids match. With
        ``ignore_fidelity=False``, the leading fidelity coordinate matters too.
        """
        space = Space()
        space.register(Fidelity("epoch", 1, 9, 3))
        space.register(Real("lr", "uniform", 0, 1, shape=2))

        algo = Hyperband(space)
        default_id = algo.get_id

        def strict_id(point):
            return algo.get_id(point, ignore_fidelity=False)

        # Fidelity ignored: equal lr vectors give equal ids.
        assert default_id(["whatever", [1, 1]]) == default_id(["is here", [1, 1]])
        assert default_id(["whatever", [1, 1]]) != default_id(["is here", [2, 2]])

        # Fidelity included: a different first coordinate changes the id.
        assert strict_id(["whatever", [1, 1]]) != strict_id(["is here", [1, 1]])
        assert strict_id(["whatever", [1, 1]]) != strict_id(["is here", [2, 2]])
        assert strict_id(["same", [1, 1]]) == strict_id(["same", [1, 1]])

        # The two modes give different ids for the very same point.
        assert strict_id(["same", [1, 1]]) != default_id(["same", [1, 1]])
Example #3
0
    def test_register_bracket_multi_fidelity(self, space):
        """Check that a point is registered inside the same bracket for diff fidelity.

        The value observed at fidelity 1 must be stored in rung 0 of the first
        bracket. Observed again at fidelity 3, it must be stored in rung 1 of
        that same bracket, while the rung-0 entry keeps the original point.
        """
        hyperband = Hyperband(space)

        value = 50
        fidelity = 1
        point = (fidelity, value)
        # Expected key: md5 of the non-fidelity value only.
        point_hash = hashlib.md5(str([value]).encode('utf-8')).hexdigest()

        hyperband.observe([point], [{'objective': 0.0}])

        bracket = hyperband.brackets[0]

        # Check the ``results`` mapping. A rung that has a 'results' key is
        # itself never empty, so ``len(rung)`` alone verified nothing.
        assert len(bracket.rungs[0]['results'])
        assert point_hash in bracket.rungs[0]['results']
        assert (0.0, point) == bracket.rungs[0]['results'][point_hash]

        fidelity = 3
        point = [fidelity, value]
        point_hash = hashlib.md5(str([value]).encode('utf-8')).hexdigest()

        hyperband.observe([point], [{'objective': 0.0}])

        # The fidelity-3 observation belongs to rung 1, so check that rung.
        assert len(bracket.rungs[1]['results'])
        assert point_hash in bracket.rungs[1]['results']
        assert (0.0, point) != bracket.rungs[0]['results'][point_hash]
        assert (0.0, point) == bracket.rungs[1]['results'][point_hash]
Example #4
0
    def test_register_next_bracket(self, space):
        """Check that a point lands in the bracket matching its higher fidelity.

        A point at fidelity 3 must be stored in rung 0 of bracket 1. A point at
        fidelity 9 must be stored in rung 0 of bracket 2. Bracket 0 must stay
        empty throughout.
        """
        hyperband = Hyperband(space)

        def count_results(index):
            rungs = hyperband.brackets[index].rungs
            return sum(len(rung['results']) for rung in rungs)

        # (value, fidelity, target bracket, expected result count per bracket)
        cases = (
            (50, 3, 1, [0, 1, 0]),
            (51, 9, 2, [0, 1, 1]),
        )
        for value, fidelity, bracket_index, expected_counts in cases:
            point = (fidelity, value)
            point_hash = hashlib.md5(str([value]).encode('utf-8')).hexdigest()

            hyperband.observe([point], [{'objective': 0.0}])

            assert [count_results(i) for i in range(3)] == expected_counts
            first_rung_results = hyperband.brackets[bracket_index].rungs[0]['results']
            assert point_hash in first_rung_results
            assert (0.0, point) == first_rung_results[point_hash]
Example #5
0
    def test_suggest_in_finite_cardinality(self):
        """Check that suggest gives None after yolo1 values 0-5 are observed at fidelity 1."""
        space = Space()
        space.register(Integer("yolo1", "uniform", 0, 5))
        space.register(Fidelity("epoch", 1, 9, 3))

        hyperband = Hyperband(space, repetitions=1)
        # Each observation's objective simply mirrors its yolo1 value.
        observations = [((1, yolo), {"objective": yolo}) for yolo in range(6)]
        for point, results in observations:
            force_observe(hyperband, point, results)

        assert hyperband.suggest() is None
Example #6
0
    def test_suggest_in_finite_cardinality(self):
        """Check that suggest gives None after yolo1 values 0-5 are observed at fidelity 1."""
        space = Space()
        space.register(Integer('yolo1', 'uniform', 0, 6))
        space.register(Fidelity('epoch', 1, 9, 3))

        hyperband = Hyperband(space, repetitions=1)
        # Each observation's objective simply mirrors its yolo1 value.
        observations = [((1, yolo), {'objective': yolo}) for yolo in range(6)]
        for point, results in observations:
            hyperband.observe([point], [results])

        assert hyperband.suggest() is None
Example #7
0
    def test_register_invalid_fidelity(self, space):
        """Check that observing a point with an unmatched fidelity raises ValueError."""
        hyperband = Hyperband(space)

        # No bracket is expected to accept fidelity 2. Layout: (fidelity, value).
        point = (2, 50)

        with pytest.raises(ValueError, match="No bracket found for point"):
            force_observe(hyperband, point, {"objective": 0.0})
Example #8
0
    def test_register_invalid_fidelity(self, space: Space):
        """Check that a trial with an unmatched fidelity is not registered."""
        algo = Hyperband(space)
        # Layout: (fidelity, value); fidelity 2 is the unmatched level here.
        invalid_trial = create_trial_for_hb((2, 50))

        algo.observe([invalid_trial])

        # The algorithm must know nothing about this trial afterwards.
        assert not algo.has_suggested(invalid_trial)
        assert not algo.has_observed(invalid_trial)
Example #9
0
    def test_register_invalid_fidelity(self, space):
        """Check that observing a point with an unmatched fidelity raises ValueError."""
        hyperband = Hyperband(space)

        # No bracket is expected to accept fidelity 2. Layout: (fidelity, value).
        point = (2, 50)

        with pytest.raises(ValueError, match='No bracket found for point'):
            hyperband.observe([point], [{'objective': 0.0}])
Example #10
0
    def test_register_not_sampled(self, space, caplog):
        """Check that observing a trial that was never sampled is logged as ignored."""
        hyperband = Hyperband(space)
        # Layout: (fidelity, value).
        unsampled_trial = create_trial_for_hb((2, 50))

        with caplog.at_level(logging.DEBUG, logger="orion.algo.hyperband"):
            hyperband.observe([unsampled_trial])

        records = caplog.records
        assert len(records) == 1
        assert "Ignoring trial" in records[0].msg
Example #11
0
    def test_register_not_sampled(self, space, caplog):
        """Check that observing a point that was never sampled is logged as ignored."""
        hyperband = Hyperband(space)
        # Layout: (fidelity, value).
        unsampled_point = (2, 50)

        with caplog.at_level(logging.INFO, logger="orion.algo.hyperband"):
            hyperband.observe([unsampled_point], [{"objective": 0.0}])

        records = caplog.records
        assert len(records) == 1
        assert "Ignoring point" in records[0].msg
Example #12
0
    def test_get_id(self, space):
        """Check that get_id ignores the first (fidelity) coordinate unless told not to."""
        hyperband = Hyperband(space)
        default_id = hyperband.get_id

        def strict_id(point):
            return hyperband.get_id(point, ignore_fidelity=False)

        # Default mode: only the second coordinate decides the id.
        assert default_id(['whatever', 1]) == default_id(['is here', 1])
        assert default_id(['whatever', 1]) != default_id(['is here', 2])

        # ignore_fidelity=False: the first coordinate counts as well.
        assert strict_id(['whatever', 1]) != strict_id(['is here', 1])
        assert strict_id(['whatever', 1]) != strict_id(['is here', 2])
        assert strict_id(['same', 1]) == strict_id(['same', 1])

        # The two modes give different ids for the very same point.
        assert strict_id(['same', 1]) != default_id(['same', 1])
Example #13
0
    def test_register_corrupted_db(self, caplog, space):
        """Check that re-observing a value at a lower fidelity logs a wrong-bracket message."""
        hyperband = Hyperband(space)
        message = 'Point registered to wrong bracket'
        value = 50

        # Fidelity 3 first: a regular registration, so nothing is reported.
        hyperband.observe([(3, value)], [{'objective': 0.0}])
        assert message not in caplog.text

        # The same value afterwards at fidelity 1 arrives out of order.
        caplog.clear()
        hyperband.observe([[1, value]], [{'objective': 0.0}])
        assert message in caplog.text
Example #14
0
    def test_register_corrupted_db(self, caplog, space):
        """Check that re-observing a value at a lower fidelity logs a wrong-bracket message."""
        hyperband = Hyperband(space)
        message = "Point registered to wrong bracket"
        value = 50

        # Fidelity 3 first: a regular registration, so nothing is reported.
        force_observe(hyperband, (3, value), {"objective": 0.0})
        assert message not in caplog.text

        # The same value afterwards at fidelity 1 arrives out of order.
        caplog.clear()
        force_observe(hyperband, [1, value], {"objective": 0.0})
        assert message in caplog.text
Example #15
0
    def test_register_corrupted_db(self, caplog, space):
        """Check that re-observing a value at a lower fidelity logs a wrong-bracket message."""
        hyperband = Hyperband(space)
        message = "Trial registered to wrong bracket"
        value = 50

        # Fidelity 3 first: a regular registration, so nothing is reported.
        high_fidelity_trial = create_trial_for_hb((3, value))
        force_observe(hyperband, high_fidelity_trial)
        assert message not in caplog.text

        # The same value afterwards at fidelity 1 arrives out of order.
        low_fidelity_trial = create_trial_for_hb((1, value))
        caplog.clear()
        force_observe(hyperband, low_fidelity_trial)
        assert message in caplog.text
Example #16
0
    def test_get_id_multidim(self):
        """Check get_id on points whose non-fidelity dimension has shape > 1.

        By default only the ``lr`` vector decides whether two ids match. With
        ``ignore_fidelity=False``, the leading fidelity coordinate matters too.
        """
        space = Space()
        space.register(Fidelity('epoch', 1, 9, 3))
        space.register(Real('lr', 'uniform', 0, 1, shape=2))

        algo = Hyperband(space)
        default_id = algo.get_id

        def strict_id(point):
            return algo.get_id(point, ignore_fidelity=False)

        # Fidelity ignored: equal lr vectors give equal ids.
        assert default_id(['whatever', [1, 1]]) == default_id(['is here', [1, 1]])
        assert default_id(['whatever', [1, 1]]) != default_id(['is here', [2, 2]])

        # Fidelity included: a different first coordinate changes the id.
        assert strict_id(['whatever', [1, 1]]) != strict_id(['is here', [1, 1]])
        assert strict_id(['whatever', [1, 1]]) != strict_id(['is here', [2, 2]])
        assert strict_id(['same', [1, 1]]) == strict_id(['same', [1, 1]])

        # The two modes give different ids for the very same point.
        assert strict_id(['same', [1, 1]]) != default_id(['same', [1, 1]])
Example #17
0
    def test_get_id(self, space):
        """Check that get_id ignores the first (fidelity) coordinate unless told not to."""
        hyperband = Hyperband(space)
        default_id = hyperband.get_id

        def strict_id(point):
            return hyperband.get_id(point, ignore_fidelity=False)

        # Default mode: only the second coordinate decides the id.
        assert default_id(["whatever", 1]) == default_id(["is here", 1])
        assert default_id(["whatever", 1]) != default_id(["is here", 2])

        # ignore_fidelity=False: the first coordinate counts as well.
        assert strict_id(["whatever", 1]) != strict_id(["is here", 1])
        assert strict_id(["whatever", 1]) != strict_id(["is here", 2])
        assert strict_id(["same", 1]) == strict_id(["same", 1])

        # The two modes give different ids for the very same point.
        assert strict_id(["same", 1]) != default_id(["same", 1])
Example #18
0
    def test_suggest_in_finite_cardinality(self):
        """Check that suggest(100) returns [] after yolo1 values 0-5 are observed at fidelity 1."""
        space = Space()
        space.register(Integer("yolo1", "uniform", 0, 5))
        space.register(Fidelity("epoch", 1, 9, 3))

        hyperband = Hyperband(space, repetitions=1)
        dim_names = ("epoch", "yolo1")
        dim_types = ("fidelity", "integer")
        for yolo in range(6):
            # The objective simply mirrors the yolo1 value.
            completed = create_trial(
                (1, yolo),
                names=dim_names,
                types=dim_types,
                results={"objective": yolo},
            )
            force_observe(hyperband, completed)

        assert hyperband.suggest(100) == []
Example #19
0
    def test_register_next_bracket(self, space: Space):
        """Check that a trial lands in the bracket matching its higher fidelity.

        A trial at fidelity 3 must be stored in rung 0 of bracket 1. A trial at
        fidelity 9 must be stored in rung 0 of bracket 2. Bracket 0 must stay
        empty throughout.
        """
        hyperband = Hyperband(space)

        # (value, fidelity, target bracket, expected result count per bracket)
        cases = (
            (50, 3, 1, [0, 1, 0]),
            (51, 9, 2, [0, 1, 1]),
        )
        for value, fidelity, bracket_index, expected_counts in cases:
            trial = create_trial_for_hb((fidelity, value), 0.0)
            trial_id = hyperband.get_id(trial, ignore_fidelity=True)

            force_observe(hyperband, trial)
            assert hyperband.brackets is not None
            brackets = hyperband.brackets

            counts = [
                sum(len(rung["results"]) for rung in brackets[index].rungs)
                for index in range(3)
            ]
            assert counts == expected_counts

            first_rung_results = brackets[bracket_index].rungs[0]["results"]
            assert trial_id in first_rung_results
            compare_registered_trial(first_rung_results[trial_id], trial)
Example #20
0
def hyperband(space):
    """Build a Hyperband instance over ``space`` with a single repetition."""
    algorithm = Hyperband(space, repetitions=1)
    return algorithm