def test_suggest_in_finite_cardinality(self):
    """Test that suggest returns an empty list once the finite search space is exhausted.

    The search space is not empty. It has a single integer dimension plus a
    fidelity. Every value of ``yolo1`` in ``range(6)`` is observed at epoch 1.
    Then the values 0 and 1, which have the lowest objectives, are observed at
    epoch 3. At that point ASHA is expected to have nothing new to suggest.
    """
    space = Space()
    space.register(Integer("yolo1", "uniform", 0, 5))
    space.register(Fidelity("epoch", 1, 9, 3))

    asha = ASHA(space)

    def observe(epoch, yolo1):
        # The objective equals the integer value, so smaller values rank better
        # (assuming ASHA minimizes the objective -- TODO confirm).
        force_observe(
            asha,
            create_trial(
                (epoch, yolo1),
                names=("epoch", "yolo1"),
                types=("fidelity", "integer"),
                results={"objective": yolo1},
            ),
        )

    # Lowest rung: every value the test assumes ``yolo1`` can take.
    for value in range(6):
        observe(1, value)
    # Next rung: presumably the trials ASHA would promote from the first rung.
    for value in range(2):
        observe(3, value)

    assert asha.suggest(1) == []
def test_register_bracket_multi_fidelity(self, space, b_config):
    """Check that one point observed at two fidelities stays in the same bracket.

    The same ``value`` is observed first at fidelity 1 and then at fidelity 3.
    Both observations must land in bracket 0, in rung 0 and rung 1
    respectively, under the same fidelity-independent trial id.
    """
    asha = ASHA(space, num_brackets=3)

    value = 50
    fidelity = 1
    trial = create_trial_for_hb((fidelity, value), 0.0)
    trial_id = asha.get_id(trial, ignore_fidelity=True)

    force_observe(asha, trial)

    bracket = asha.brackets[0]

    # Inspect the rung's "results" mapping, not the rung itself. The rung
    # always has a "results" key, so ``len(rung)`` would pass even when the
    # rung holds no results.
    assert len(bracket.rungs[0]["results"]) == 1
    assert trial_id in bracket.rungs[0]["results"]
    assert bracket.rungs[0]["results"][trial_id][0] == 0.0
    assert bracket.rungs[0]["results"][trial_id][1].params == trial.params

    fidelity = 3
    trial = create_trial_for_hb((fidelity, value), 0.0)
    trial_id = asha.get_id(trial, ignore_fidelity=True)

    force_observe(asha, trial)

    assert len(bracket.rungs[1]["results"]) == 1
    assert trial_id in bracket.rungs[1]["results"]
    # Rung 0 still holds the fidelity-1 trial, not the fidelity-3 one.
    assert bracket.rungs[0]["results"][trial_id][1].params != trial.params
    assert bracket.rungs[1]["results"][trial_id][0] == 0.0
    assert bracket.rungs[1]["results"][trial_id][1].params == trial.params
def test_register_corrupted_db(self, caplog, space, b_config):
    """Check that observing a point at a lower fidelity after a higher one is flagged.

    The same value is observed first at fidelity 3 and then at fidelity 1. The
    first observation must not emit the wrong-bracket message; the second one
    must emit it.
    """
    wrong_bracket_msg = "Trial registered to wrong bracket"
    asha = ASHA(space, num_brackets=3)
    shared_value = 50

    high_fidelity_trial = create_trial_for_hb((3, shared_value))
    force_observe(asha, high_fidelity_trial)
    assert wrong_bracket_msg not in caplog.text

    low_fidelity_trial = create_trial_for_hb((1, shared_value), objective=0.0)
    # Drop earlier log records so the next check only sees the second observation.
    caplog.clear()
    force_observe(asha, low_fidelity_trial)
    assert wrong_bracket_msg in caplog.text
def test_register_next_bracket(self, space, b_config):
    """Check that a point observed at a higher fidelity goes to a later bracket.

    A point observed at fidelity 3 must land in rung 0 of bracket 1. A
    different point observed afterwards at fidelity 9 must land in rung 0 of
    bracket 2. After each step, the number of results in each bracket is
    checked.
    """
    asha = ASHA(space, num_brackets=3)

    def results_in(bracket_index):
        return sum(
            len(rung["results"]) for rung in asha.brackets[bracket_index].rungs
        )

    # (value, fidelity, target bracket, expected result count per bracket)
    steps = (
        (50, 3, 1, (0, 1, 0)),
        (51, 9, 2, (0, 1, 1)),
    )
    for value, fidelity, target, expected_counts in steps:
        trial = create_trial_for_hb((fidelity, value), 0.0)
        trial_id = asha.get_id(trial, ignore_fidelity=True)

        force_observe(asha, trial)

        for bracket_index, expected in enumerate(expected_counts):
            assert results_in(bracket_index) == expected

        first_rung_results = asha.brackets[target].rungs[0]["results"]
        assert trial_id in first_rung_results
        compare_registered_trial(first_rung_results[trial_id], trial)