def test_register_bracket_multi_fidelity(self, space: Space):
    """Check that a point is registered inside the same bracket for different fidelities.

    The same hyperparameter value is observed first at fidelity 1 and then at
    fidelity 3. Both observations share an id (fidelity ignored), so they must
    end up in rungs 0 and 1 of the first bracket, with rung 0 still holding
    the fidelity-1 trial.
    """
    hyperband = Hyperband(space)

    value = 50
    fidelity = 1
    trial = create_trial_for_hb((fidelity, value), 0.0)
    trial_id = hyperband.get_id(trial, ignore_fidelity=True)

    force_observe(hyperband, trial)

    assert hyperband.brackets is not None
    bracket = hyperband.brackets[0]

    # A rung dict also carries other keys (e.g. "resources"), so its own
    # length is never zero; assert on the registered results instead.
    assert bracket.rungs[0]["results"]
    assert trial_id in bracket.rungs[0]["results"]
    assert bracket.rungs[0]["results"][trial_id][0] == 0.0
    assert bracket.rungs[0]["results"][trial_id][1].params == trial.params

    fidelity = 3
    trial = create_trial_for_hb((fidelity, value), 0.0)
    trial_id = hyperband.get_id(trial, ignore_fidelity=True)

    force_observe(hyperband, trial)

    assert bracket.rungs[1]["results"]
    assert trial_id in bracket.rungs[1]["results"]
    # Rung 0 keeps the fidelity-1 trial; it is not overwritten by fidelity 3.
    assert bracket.rungs[0]["results"][trial_id][1].params != trial.params
    assert bracket.rungs[1]["results"][trial_id][0] == 0.0
    assert bracket.rungs[1]["results"][trial_id][1].params == trial.params
def test_get_id(self, space):
    """Test valid id of points.

    By default the first coordinate (the fidelity) is left out of the id, so
    only the remaining coordinate matters; with ``ignore_fidelity=False`` the
    fidelity takes part in the id as well.
    """
    hyperband = Hyperband(space)

    def id_with_fidelity(point):
        return hyperband.get_id(point, ignore_fidelity=False)

    # Fidelity ignored: the second coordinate alone decides.
    assert hyperband.get_id(['whatever', 1]) == hyperband.get_id(['is here', 1])
    assert hyperband.get_id(['whatever', 1]) != hyperband.get_id(['is here', 2])

    # Fidelity included: a difference in either coordinate changes the id.
    assert id_with_fidelity(['whatever', 1]) != id_with_fidelity(['is here', 1])
    assert id_with_fidelity(['whatever', 1]) != id_with_fidelity(['is here', 2])
    assert id_with_fidelity(['same', 1]) == id_with_fidelity(['same', 1])

    # The two id flavours differ for the very same point.
    assert id_with_fidelity(['same', 1]) != hyperband.get_id(['same', 1])
def test_suggest_duplicates_between_calls(self, monkeypatch, hyperband: Hyperband,
                                          bracket: HyperbandBracket):
    """Trials already known from an earlier suggest call must not be suggested again.

    A trial is pre-registered in rung 0 of bracket 0; when the sampler yields
    it again, ``suggest`` must skip it and return the fresh trial first.
    """
    hyperband.brackets = [bracket]
    assert bracket.owner is hyperband

    fidelity = 1
    known = create_trial_for_hb((fidelity, 0.0))
    fresh = create_trial_for_hb((fidelity, 0.5))

    known_id = hyperband.get_id(known, ignore_fidelity=True)
    bracket.rungs[0]["results"] = {known_id: (0.0, known)}  # type: ignore
    hyperband.trial_to_brackets[known_id] = 0

    # Pad the sampler output to ten trials in total.
    padding = [create_trial_for_hb((fidelity, i)) for i in range(8)]
    mock_samples(hyperband, [known, fresh, *padding])

    suggested = hyperband.suggest(100)
    assert suggested is not None
    assert suggested[0].params == fresh.params
def test_suggest_duplicates_one_call(self, monkeypatch, hyperband: Hyperband,
                                     bracket: HyperbandBracket):
    """Test that same points are not allowed in the same suggest call of
    the same hyperband execution.
    """
    hyperband.brackets = [bracket]
    assert bracket.owner is hyperband

    first_points = list(
        map(create_trial_for_hb, [(1, 0.0), (1, 1.0), (1, 1.0), (1, 2.0)]))

    mock_samples(hyperband, first_points * 2)
    samples = hyperband.suggest(100)
    assert samples is not None
    # The repeated (1, 1.0) sample is skipped, so 2.0 follows 1.0 directly.
    assert samples[0].params["lr"] == 0.0
    assert samples[1].params["lr"] == 1.0
    assert samples[2].params["lr"] == 2.0

    # Second call: 0.0, 1.0 and 2.0 were already suggested above, so only
    # 5.0 and 4.0 are expected to come out.
    mock_samples(
        hyperband,
        list(
            map(
                create_trial_for_hb,
                [
                    (3, 0.0),
                    (3, 1.0),
                    (3, 1.0),
                    (3, 2.0),
                    (3, 5.0),
                    (3, 4.0),
                ],
            )),
    )
    # Map the id of the (1, 0.0) point to bracket 0. The original repeated
    # this exact assignment twice; the second copy had no effect.
    hyperband.trial_to_brackets[hyperband.get_id(
        create_trial_for_hb((1, 0.0)), ignore_fidelity=True)] = 0
    samples = hyperband.suggest(100)
    assert samples is not None
    assert samples[0].params["lr"] == 5.0
    assert samples[1].params["lr"] == 4.0
def test_get_id_multidim(self):
    """Test valid id for points with dim of shape > 1.

    The space holds a fidelity followed by a 2-dimensional real ``lr``; ids
    compare the ``lr`` vector and, when requested, the fidelity as well.
    """
    space = Space()
    space.register(Fidelity('epoch', 1, 9, 3))
    space.register(Real('lr', 'uniform', 0, 1, shape=2))
    hyperband = Hyperband(space)

    def id_with_fidelity(point):
        return hyperband.get_id(point, ignore_fidelity=False)

    # Fidelity ignored: only the lr vector decides.
    assert hyperband.get_id(['whatever', [1, 1]]) == hyperband.get_id(['is here', [1, 1]])
    assert hyperband.get_id(['whatever', [1, 1]]) != hyperband.get_id(['is here', [2, 2]])

    # Fidelity included: both coordinates decide.
    assert id_with_fidelity(['whatever', [1, 1]]) != id_with_fidelity(['is here', [1, 1]])
    assert id_with_fidelity(['whatever', [1, 1]]) != id_with_fidelity(['is here', [2, 2]])
    assert id_with_fidelity(['same', [1, 1]]) == id_with_fidelity(['same', [1, 1]])

    # The two id flavours differ for the very same point.
    assert id_with_fidelity(['same', [1, 1]]) != hyperband.get_id(['same', [1, 1]])
def test_get_id(self, space):
    """Test valid id of points.

    ``get_id`` leaves out the first coordinate (the fidelity) unless
    ``ignore_fidelity=False`` is passed.
    """
    hyperband = Hyperband(space)

    fidelity_ignored = [
        # (point_a, point_b, ids_expected_equal)
        (["whatever", 1], ["is here", 1], True),
        (["whatever", 1], ["is here", 2], False),
    ]
    for point_a, point_b, expect_equal in fidelity_ignored:
        assert (hyperband.get_id(point_a) == hyperband.get_id(point_b)) == expect_equal

    fidelity_included = [
        (["whatever", 1], ["is here", 1], False),
        (["whatever", 1], ["is here", 2], False),
        (["same", 1], ["same", 1], True),
    ]
    for point_a, point_b, expect_equal in fidelity_included:
        id_a = hyperband.get_id(point_a, ignore_fidelity=False)
        id_b = hyperband.get_id(point_b, ignore_fidelity=False)
        assert (id_a == id_b) == expect_equal

    # Including or ignoring the fidelity yields different ids for one point.
    full = hyperband.get_id(["same", 1], ignore_fidelity=False)
    assert full != hyperband.get_id(["same", 1])
def test_update_rungs_return_candidate(self, hyperband: Hyperband,
                                       bracket: HyperbandBracket, rung_1: RungDict):
    """Check that ``promote`` on rung 1 returns a valid, modified candidate.

    The (3, 0.0) trial must remain registered in rung 1, and the first
    returned candidate must carry an ``epoch`` of 9.
    """
    assert bracket.owner is hyperband
    bracket.rungs[1] = rung_1

    expected = create_trial_for_hb((3, 0.0), 0.0)
    expected_id = hyperband.get_id(expected, ignore_fidelity=True)

    candidates = bracket.promote(1)

    registered = bracket.rungs[1]["results"]
    assert expected_id in registered
    assert registered[expected_id][1].params == expected.params
    assert candidates[0].params["epoch"] == 9
def test_register_next_bracket(self, space: Space):
    """Check that a point is registered inside the good bracket when higher fidelity.

    A trial observed at fidelity 3 must land in the second bracket and one
    observed at fidelity 9 in the third, leaving the other brackets untouched.
    """
    hyperband = Hyperband(space)

    def results_per_bracket():
        # Number of registered results in each of the first three brackets.
        return [
            sum(len(rung["results"]) for rung in hyperband.brackets[index].rungs)
            for index in range(3)
        ]

    scenarios = [
        # (value, fidelity, receiving bracket, expected counts afterwards)
        (50, 3, 1, [0, 1, 0]),
        (51, 9, 2, [0, 1, 1]),
    ]
    for value, fidelity, target, expected_counts in scenarios:
        trial = create_trial_for_hb((fidelity, value), 0.0)
        trial_id = hyperband.get_id(trial, ignore_fidelity=True)

        force_observe(hyperband, trial)

        assert hyperband.brackets is not None
        assert results_per_bracket() == expected_counts

        entry_results = hyperband.brackets[target].rungs[0]["results"]
        assert trial_id in entry_results
        compare_registered_trial(entry_results[trial_id], trial)
def test_register(self, hyperband: Hyperband, bracket: HyperbandBracket):
    """Check that a point is correctly registered inside a bracket.

    After ``bracket.register(trial)``, rung 0 must map the trial's
    fidelity-less id to ``(objective value, trial)``.
    """
    assert bracket.owner is hyperband
    trial = create_trial_for_hb((1, 0.0), 0.0)
    trial_id = hyperband.get_id(trial, ignore_fidelity=True)

    bracket.register(trial)

    results = bracket.rungs[0]["results"]
    # A rung dict also carries other keys (e.g. "resources"), so its own
    # length is never zero; assert on the registered results instead.
    assert results
    assert trial_id in results
    assert trial.objective is not None
    assert results[trial_id][0] == trial.objective.value
    assert results[trial_id][1].to_dict() == trial.to_dict()
def test_promotion_with_rung_1_hit(self, hyperband: Hyperband,
                                   bracket: HyperbandBracket, rung_0: RungDict):
    """Test that get_candidates gives the next best thing if a point is already in rung 1.

    The point (1, 0.0) is placed in rung 1 as if already promoted; the first
    candidate from rung 0 must then be the point (1, 1).
    """
    promoted = create_trial_for_hb((1, 0.0), None)
    assert bracket.owner is hyperband
    bracket.rungs[0] = rung_0
    assert promoted.objective is not None

    promoted_id = hyperband.get_id(promoted, ignore_fidelity=True)
    rung_1_results = bracket.rungs[1]["results"]
    rung_1_results[promoted_id] = (promoted.objective.value, promoted)

    candidates = bracket.get_candidates(0)
    expected = create_trial_for_hb((1, 1), 0.0)
    assert candidates[0].params == expected.params
def test_suggest_inf_duplicates(
    self,
    monkeypatch,
    hyperband: Hyperband,
    bracket: HyperbandBracket,
    rung_0: RungDict,
    rung_1: RungDict,
    rung_2: RungDict,
):
    """Test that sampling only collisions makes ``suggest`` return an empty list.

    Every sample is the same trial, already mapped to bracket 0, so none of
    them is new and ``suggest`` must return ``[]`` (not ``None``).
    """
    # NOTE(review): monkeypatch and the rung fixtures are not used in the
    # body; presumably requested for fixture setup only — confirm.
    hyperband.brackets = [bracket]
    assert bracket.owner is hyperband

    zhe_trial = create_trial_for_hb(("fidelity", 0.0))
    hyperband.trial_to_brackets[hyperband.get_id(zhe_trial, ignore_fidelity=True)] = 0

    mock_samples(hyperband, [zhe_trial] * 2)

    assert hyperband.suggest(100) == []
def test_register(
    self,
    hyperband: Hyperband,
    bracket: HyperbandBracket,
    rung_0: RungDict,
    rung_1: RungDict,
):
    """Check that a point observed by hyperband is registered inside the bracket.

    With ``rung_0`` and ``rung_1`` installed, observing a fidelity-1 trial
    must store ``(0.0, trial)`` under its fidelity-less id in rung 0.
    """
    hyperband.brackets = [bracket]
    assert bracket.owner is hyperband
    bracket.rungs = [rung_0, rung_1]
    trial = create_trial_for_hb((1, 0.0), 0.0)
    trial_id = hyperband.get_id(trial, ignore_fidelity=True)

    hyperband.observe([trial])

    results = bracket.rungs[0]["results"]
    # A rung dict also carries other keys (e.g. "resources"), so its own
    # length is never zero; assert on the registered results instead.
    assert results
    assert trial_id in results
    assert results[trial_id][0] == 0.0
    assert results[trial_id][1].params == trial.params
def force_observe(hyperband: Hyperband, trial: Trial) -> None:
    """Register ``trial`` with ``hyperband`` and observe it, bypassing ``suggest``.

    If the trial's fidelity-less id is not yet in ``trial_to_brackets``, it is
    mapped to the first bracket whose rung 0 ``"resources"`` equals the
    trial's fidelity before ``observe`` is called.

    Raises:
        AssertionError: if ``hyperband.brackets`` is None or no bracket starts
            at the trial's fidelity.
    """
    hyperband.register(trial)

    id_wo_fidelity = hyperband.get_id(trial, ignore_fidelity=True)
    bracket_index = hyperband.trial_to_brackets.get(id_wo_fidelity)
    if bracket_index is None:
        fidelity = flatten(trial.params)[hyperband.fidelity_index]
        assert hyperband.brackets is not None
        # First matching bracket, as the previous ``[...][0]`` lookup did,
        # but without building the whole list and with a clear failure.
        bracket_index = next(
            (
                index
                for index, bracket in enumerate(hyperband.brackets)
                if bracket.rungs[0]["resources"] == fidelity
            ),
            None,
        )
        assert bracket_index is not None, (
            f"No bracket starts at fidelity {fidelity!r}"
        )
        hyperband.trial_to_brackets[id_wo_fidelity] = bracket_index

    hyperband.observe([trial])
def test_get_id_multidim(self):
    """Test valid id for points with dim of shape > 1.

    With a fidelity followed by a 2-dimensional ``lr``, ids leave out the
    fidelity by default and include it when ``ignore_fidelity=False``.
    """
    space = Space()
    for dimension in (
        Fidelity("epoch", 1, 9, 3),
        Real("lr", "uniform", 0, 1, shape=2),
    ):
        space.register(dimension)
    hyperband = Hyperband(space)

    fidelity_ignored = [
        # (point_a, point_b, ids_expected_equal)
        (["whatever", [1, 1]], ["is here", [1, 1]], True),
        (["whatever", [1, 1]], ["is here", [2, 2]], False),
    ]
    for point_a, point_b, expect_equal in fidelity_ignored:
        assert (hyperband.get_id(point_a) == hyperband.get_id(point_b)) == expect_equal

    fidelity_included = [
        (["whatever", [1, 1]], ["is here", [1, 1]], False),
        (["whatever", [1, 1]], ["is here", [2, 2]], False),
        (["same", [1, 1]], ["same", [1, 1]], True),
    ]
    for point_a, point_b, expect_equal in fidelity_included:
        id_a = hyperband.get_id(point_a, ignore_fidelity=False)
        id_b = hyperband.get_id(point_b, ignore_fidelity=False)
        assert (id_a == id_b) == expect_equal

    # Including or ignoring the fidelity yields different ids for one point.
    full = hyperband.get_id(["same", [1, 1]], ignore_fidelity=False)
    assert full != hyperband.get_id(["same", [1, 1]])