def optimize_by_continuous_asha(
    objective: Any,
    dataset: TuneDataset,
    plan: List[Tuple[float, int]],
    checkpoint_path: str = "",
    always_checkpoint: bool = False,
    study_early_stop: Optional[Callable[[List[Any], List[RungHeap]], bool]] = None,
    trial_early_stop: Optional[
        Callable[[TrialReport, List[TrialReport], List[RungHeap]], bool]
    ] = None,
    monitor: Any = None,
) -> StudyResult:
    """Run an iterative study driven by an ``ASHAJudge``.

    Args:
        objective: Anything ``TUNE_OBJECT_FACTORY.make_iterative_objective``
            can turn into an iterative objective.
        dataset: The tune dataset handed to ``IterativeStudy.optimize``.
        plan: Passed verbatim to ``ASHAJudge`` as ``schedule``. Presumably
            one ``(budget, trial count)`` pair per rung -- TODO confirm
            against ``ASHAJudge``.
        checkpoint_path: Base directory for checkpoints. It is resolved
            through ``TUNE_OBJECT_FACTORY.get_path_or_temp``; an empty value
            presumably falls back to a temporary location (verify).
        always_checkpoint: Forwarded to ``ASHAJudge``.
        study_early_stop: Optional study-level stop callback, forwarded to
            ``ASHAJudge``.
        trial_early_stop: Optional trial-level stop callback, forwarded to
            ``ASHAJudge``.
        monitor: Anything ``TUNE_OBJECT_FACTORY.make_monitor`` accepts.

    Returns:
        Whatever ``IterativeStudy.optimize`` returns for this run.
    """
    iterative_objective = TUNE_OBJECT_FACTORY.make_iterative_objective(objective)
    resolved_monitor = TUNE_OBJECT_FACTORY.make_monitor(monitor)
    base_dir = TUNE_OBJECT_FACTORY.get_path_or_temp(checkpoint_path)

    asha_judge = ASHAJudge(
        schedule=plan,
        always_checkpoint=always_checkpoint,
        study_early_stop=study_early_stop,
        trial_early_stop=trial_early_stop,
        monitor=resolved_monitor,
    )

    # Every call checkpoints into its own freshly created, uuid-named
    # subdirectory of the base directory.
    run_dir = os.path.join(base_dir, str(uuid4()))
    FileSystem().makedirs(run_dir, recreate=True)

    return IterativeStudy(iterative_objective, checkpoint_path=run_dir).optimize(
        dataset, judge=asha_judge
    )
def test_asha_stop():
    def should_stop(keys, rungs):
        # Study-level stopper. While walking the rungs it stops at the first
        # empty one, and gives up (no stop) if any earlier rung is not yet
        # full. Once at least two rungs are populated, it stops the study when
        # the previous rung's best minus the last rung's best is below 0.2.
        best_per_rung = []
        for rung in rungs:
            if len(rung) == 0:
                break
            if not rung.full:
                return False
            best_per_rung.append(rung.best)
        return len(best_per_rung) >= 2 and best_per_rung[-2] - best_per_rung[-1] < 0.2

    judge = ASHAJudge(
        schedule=[(1.0, 2), (2.0, 2), (3.0, 1)],
        always_checkpoint=True,
        study_early_stop=should_stop,
    )

    # Each step is (trial_id, metric, rung, expected budget). Because
    # always_checkpoint=True, every decision is also expected to request a
    # checkpoint.
    steps = [
        ("a", 0.6, 0, 2.0),
        ("b", 0.5, 0, 2.0),
        ("c", 0.4, 0, 2.0),
        ("b", 0.45, 1, 3.0),
        ("c", 0.39, 1, 3.0),
        # The best metrics of rungs[0] and rungs[1] are now too close, so the
        # study no longer accepts new trials.
        ("x", 0.45, 0, 0.0),
        # Trial ids that already exist can still be accepted.
        ("b", 0.45, 1, 3.0),
        # The study has already stopped.
        ("a", 0.44, 2, 0.0),
    ]
    for trial_id, metric, rung, expected_budget in steps:
        decision = judge.judge(rp(trial_id, metric, rung))
        assert expected_budget == decision.budget
        assert decision.should_checkpoint
def test_asha_judge_simple_happy_path():
    judge = ASHAJudge(schedule=[(1.0, 2), (2.0, 1)])

    # First-rung reports are granted the 2.0 budget. always_checkpoint is left
    # at its default here, and no checkpoint is expected.
    for trial_id, metric in (("a", 0.5), ("b", 0.6)):
        decision = judge.judge(rp(trial_id, metric, 0))
        assert 2.0 == decision.budget
        assert not decision.should_checkpoint

    # When "a" reports on the last rung, the stop criteria are met. From then
    # on, other trials (even a brand-new "c") get no more budget.
    for trial_id, metric, rung in (("a", 0.4, 1), ("c", 0.2, 0)):
        decision = judge.judge(rp(trial_id, metric, rung))
        assert decision.should_stop
        assert decision.should_checkpoint
def test_trial_stop():
    def should_stop(report, history, rungs):
        # Trial-level stopper: stop a trial if its id is missing from any rung
        # below the one it is currently reporting on.
        return any(report.trial_id not in rung for rung in rungs[: report.rung])

    judge = ASHAJudge(
        schedule=[(1.0, 2), (2.0, 2), (3.0, 1)],
        always_checkpoint=True,
        trial_early_stop=should_stop,
    )

    # Three first-rung reports. Each one is granted the 2.0 budget and
    # requests a checkpoint.
    for trial_id, metric in (("a", 0.6), ("b", 0.5), ("c", 0.4)):
        decision = judge.judge(rp(trial_id, metric, 0))
        assert 2.0 == decision.budget
        assert decision.should_checkpoint

    # "a" was kicked out of rung 0 by "c", so the trial stopper fires when "a"
    # reports on rung 1.
    decision = judge.judge(rp("a", 0.1, 1))
    assert decision.should_stop
    assert decision.should_checkpoint