def test_suggest_promote_identic_objectives(
    self,
    asha: ASHA,
    bracket: ASHABracket,
    big_rung_0: RungDict,
    big_rung_1: RungDict,
):
    """Check that promotion works when every trial shares the same objective.

    The first rung is filled with nine trials that all report objective 0.
    Two suggestions are then requested, and both are expected to carry a
    fidelity value of 3.
    """
    asha.brackets = [bracket]
    bracket.owner = asha

    budget = 1
    rung_results = {}
    for value in np.linspace(0, 8, 9):
        candidate = create_trial_for_hb((budget, value), objective=0)
        # Key each entry by a hash that ignores fidelity and experiment.
        key = candidate.compute_trial_hash(
            candidate,
            ignore_fidelity=True,
            ignore_experiment=True,
        )
        assert candidate.objective is not None
        rung_results[key] = (candidate.objective.value, candidate)

    bracket.rungs[0] = RungDict(n_trials=9, resources=budget, results=rung_results)

    promoted = asha.suggest(2)
    assert len(promoted) == 2

    at_fidelity_3 = [
        trial for trial in promoted if trial.params[asha.fidelity_index] == 3
    ]
    assert len(at_fidelity_3) == 2
def test_suggest_promote(self, asha: ASHA, bracket: ASHABracket, rung_0: RungDict):
    """Check that a single suggestion returns the expected promoted trial.

    With ``rung_0`` installed as the bracket's first rung, the first suggested
    trial must have params ``{"epoch": 3, "lr": 0.0}``.
    """
    asha.brackets = [bracket]
    assert bracket.owner is asha
    bracket.rungs[0] = rung_0

    expected_params = {"epoch": 3, "lr": 0.0}
    suggested = asha.suggest(1)
    first = suggested[0]
    assert first.params == expected_params
def test_register(self, asha: ASHA, bracket: ASHABracket, rung_0: RungDict, rung_1: RungDict):
    """Check that an observed trial ends up in the bracket's first rung.

    After ``asha.observe`` is called with a trial built from ``(1, 0.0)`` and
    objective ``0.0``, the first rung's results must contain an entry under
    the trial's fidelity-agnostic id, holding the objective and a trial with
    the same params.
    """
    asha.brackets = [bracket]
    assert bracket.owner is asha
    bracket.rungs = [rung_0, rung_1]

    observed = create_trial_for_hb((1, 0.0), 0.0)
    key = asha.get_id(observed, ignore_fidelity=True)

    asha.observe([observed])

    first_rung = bracket.rungs[0]
    assert len(first_rung)
    assert key in first_rung["results"]

    entry = first_rung["results"][key]
    assert entry[0] == 0.0
    assert entry[1].params == observed.params
def test_suggest_promote_many(
    self,
    asha: ASHA,
    bracket: ASHABracket,
    big_rung_0: RungDict,
    big_rung_1: RungDict,
):
    """Check that several suggestions are spread over the expected fidelities.

    With both big rungs installed, three suggestions are requested: two are
    expected with a fidelity value of 9 and one with a fidelity value of 3.
    """
    asha.brackets = [bracket]
    assert bracket.owner is asha
    bracket.rungs[0] = big_rung_0
    bracket.rungs[1] = big_rung_1

    promoted = asha.suggest(3)

    def count_at(fidelity):
        # Number of suggested trials whose fidelity param equals `fidelity`.
        return len(
            [trial for trial in promoted if trial.params[asha.fidelity_index] == fidelity]
        )

    assert len(promoted) == 3
    assert count_at(9) == 2
    assert count_at(3) == 1
def test_suggest_new(
    self,
    monkeypatch,
    asha: ASHA,
    bracket: ASHABracket,
    rung_0: RungDict,
    rung_1: RungDict,
    rung_2: RungDict,
):
    """Check that a suggestion comes from the space sampler.

    ``asha.space.sample`` is replaced by a stub that yields one fixed trial;
    the suggested trial must have params ``{"epoch": 1, "lr": 0.5}``.
    """
    asha.brackets = [bracket]
    assert bracket.owner is asha

    def fake_sample(num=1, seed=None):
        # Always hand back one freshly built trial, ignoring the arguments.
        return [create_trial_for_hb(("fidelity", 0.5))]

    monkeypatch.setattr(asha.space, "sample", fake_sample)

    suggested = asha.suggest(1)
    expected_params = {"epoch": 1, "lr": 0.5}
    assert suggested[0].params == expected_params
def test_suggest_inf_duplicates(
    self,
    monkeypatch,
    asha: ASHA,
    bracket: ASHABracket,
    rung_0: RungDict,
    rung_1: RungDict,
    rung_2: RungDict,
):
    """Check that endless sampling collisions produce an empty suggestion list.

    The stub sampler only ever returns a trial whose id is already present in
    ``asha.trial_to_brackets``, so ``asha.suggest(1)`` must return ``[]``.
    """
    asha.brackets = [bracket]
    assert bracket.owner is asha

    known_trial = create_trial_for_hb((1, 0.0))
    known_id = asha.get_id(known_trial, ignore_fidelity=True)
    # Mark the trial as already assigned to bracket 0.
    asha.trial_to_brackets[known_id] = 0

    def fake_sample(num=1, seed=None):
        # Every draw collides with the already-known trial.
        return [known_trial]

    monkeypatch.setattr(asha.space, "sample", fake_sample)

    suggested = asha.suggest(1)
    assert suggested == []
def test_suggest_duplicates(
    self,
    monkeypatch,
    asha: ASHA,
    bracket: ASHABracket,
    rung_0: RungDict,
    rung_1: RungDict,
    rung_2: RungDict,
):
    """Check that an already-registered sample is skipped for a new one.

    The stub sampler returns a known trial followed by an unseen one; the
    suggested trial must carry the params of the unseen trial.
    """
    asha.brackets = [bracket]
    assert bracket.owner is asha

    budget = 1
    known_trial = create_trial_for_hb((budget, 0.0))
    fresh_trial = create_trial_for_hb((budget, 0.5))

    known_id = asha.get_id(known_trial, ignore_fidelity=True)

    # Put the known trial in the first rung and register it with the algorithm.
    bracket.rungs[0] = {
        "n_trials": 2,
        "resources": 1,
        "results": {known_id: (0.0, known_trial)},
    }
    asha.trial_to_brackets[known_id] = 0
    asha.register(known_trial)

    sample_pool = [known_trial, fresh_trial]

    def fake_sample(num=1, seed=None):
        # Same list object on every call: the duplicate first, then the new trial.
        return sample_pool

    monkeypatch.setattr(asha.space, "sample", fake_sample)

    suggested = asha.suggest(1)
    assert suggested[0].params == fresh_trial.params