def create_scheduler(
        scheduler,
        **kwargs,
):
    """Look up a trial scheduler by its short name and construct it.

    Lets callers switch schedulers by changing a string (for example from a
    config) instead of importing a specific scheduler class.

    Args:
        scheduler (str): Scheduler name. It is lower-cased before lookup, so
            matching is case-insensitive.
        **kwargs: Candidate constructor arguments. Only keys that appear in
            the chosen scheduler's argument list (as reported by
            ``get_function_args``) are passed on; any other keys are dropped
            without error.

    Returns:
        ray.tune.schedulers.trial_scheduler.TrialScheduler: The scheduler.

    Raises:
        ValueError: If the lower-cased name is not a key of
            ``SCHEDULER_IMPORT``.

    Example:
        >>> scheduler = tune.create_scheduler('pbt', **pbt_kwargs)
    """
    name = scheduler.lower()
    if name not in SCHEDULER_IMPORT:
        raise ValueError(f"The `scheduler` argument must be one of "
                         f"{list(SCHEDULER_IMPORT)}. "
                         f"Got: {name}")

    scheduler_cls = SCHEDULER_IMPORT[name]

    # Keep only the kwargs this constructor declares; unlisted keys are
    # dropped instead of being passed through.
    accepted = get_function_args(scheduler_cls)
    forwarded = {}
    for key, value in kwargs.items():
        if key in accepted:
            forwarded[key] = value
    return scheduler_cls(**forwarded)
def create_searcher(
        search_alg,
        **kwargs,
):
    """Look up a search algorithm by its short name and construct it.

    Lets callers switch search algorithms by changing a string instead of
    importing a specific searcher class.

    Args:
        search_alg (str): Search algorithm name. It is lower-cased before
            lookup, so matching is case-insensitive.
        **kwargs: Candidate constructor arguments. Commonly includes
            ``metric`` (the training-result attribute holding the objective)
            and ``mode`` (one of ``"min"``/``"max"``, i.e. whether the metric
            is minimized or maximized). Only keys that appear in the chosen
            searcher's argument list (as reported by ``get_function_args``)
            are passed on; any other keys are dropped without error.

    Returns:
        ray.tune.suggest.Searcher: The search algorithm.

    Raises:
        ValueError: If the lower-cased name is not a key of
            ``SEARCH_ALG_IMPORT``.

    Example:
        >>> search_alg = tune.create_searcher('ax')
    """
    name = search_alg.lower()
    if name not in SEARCH_ALG_IMPORT:
        raise ValueError(f"The `search_alg` argument must be one of "
                         f"{list(SEARCH_ALG_IMPORT)}. "
                         f"Got: {name}")

    # Registry values are zero-argument callables that return the searcher
    # class (presumably importing it lazily -- see the registry definition).
    load_searcher_cls = SEARCH_ALG_IMPORT[name]
    searcher_cls = load_searcher_cls()

    # Keep only the kwargs this constructor declares.
    accepted = get_function_args(searcher_cls)
    forwarded = {}
    for key, value in kwargs.items():
        if key in accepted:
            forwarded[key] = value
    return searcher_cls(**forwarded)
def create_scheduler(
        scheduler,
        **kwargs,
):
    """Instantiate a scheduler based on the given string.

    This is useful for swapping between different schedulers.

    Args:
        scheduler (str): Name of the scheduler to use. It is lower-cased
            before lookup, so it is matched case-insensitively against the
            keys of the table below.
        **kwargs: Scheduler parameters. Only keys that appear in the chosen
            scheduler's argument list (as reported by ``get_function_args``)
            are passed to its constructor; other keys are ignored.

    Returns:
        ray.tune.schedulers.trial_scheduler.TrialScheduler: The scheduler.

    Raises:
        ValueError: If ``scheduler`` is not a recognized scheduler name.

    Example:
        >>> scheduler = tune.create_scheduler('pbt', **pbt_kwargs)
    """
    # Several names are aliases for the same scheduler class.
    SCHEDULER_IMPORT = {
        "fifo": FIFOScheduler,
        "async_hyperband": AsyncHyperBandScheduler,
        "asynchyperband": AsyncHyperBandScheduler,
        "median_stopping_rule": MedianStoppingRule,
        "medianstopping": MedianStoppingRule,
        "hyperband": HyperBandScheduler,
        "hb_bohb": HyperBandForBOHB,
        "pbt": PopulationBasedTraining,
        "pbt_replay": PopulationBasedTrainingReplay,
        # NOTE(review): _pb2_importer is defined elsewhere. If it is a
        # *args/**kwargs wrapper rather than the PB2 class itself,
        # get_function_args below inspects the wrapper's signature and may
        # filter out every PB2 option -- confirm.
        "pb2": _pb2_importer,
    }
    scheduler = scheduler.lower()
    if scheduler not in SCHEDULER_IMPORT:
        raise ValueError(f"The `scheduler` argument must be one of "
                         f"{list(SCHEDULER_IMPORT)}. "
                         f"Got: {scheduler}")

    SchedulerClass = SCHEDULER_IMPORT[scheduler]
    # Drop kwargs the constructor does not declare.
    scheduler_args = get_function_args(SchedulerClass)
    trimmed_kwargs = {k: v for k, v in kwargs.items() if k in scheduler_args}

    return SchedulerClass(**trimmed_kwargs)
def create_searcher(
        search_alg,
        **kwargs,
):
    """Instantiate a search algorithm based on the given string.

    This is useful for swapping between different search algorithms.

    Args:
        search_alg (str): The search algorithm to use. It is lower-cased
            before lookup, so matching is case-insensitive.
        **kwargs: Additional parameters for the chosen class's constructor.
            Commonly includes ``metric`` (the training-result attribute
            holding the objective value) and ``mode`` (one of
            ``"min"``/``"max"``, i.e. whether the metric is minimized or
            maximized). Only keys that appear in the chosen class's argument
            list (as reported by ``get_function_args``) are passed on; other
            keys are ignored.

    Returns:
        ray.tune.suggest.Searcher: The search algorithm.

    Raises:
        ValueError: If ``search_alg`` is not a recognized algorithm name.

    Example:
        >>> search_alg = tune.create_searcher('ax')
    """

    # Each helper imports its searcher module only when it is called, so an
    # algorithm's module is loaded only if that algorithm is requested.
    def _import_variant_generator():
        return BasicVariantGenerator

    def _import_ax_search():
        from ray.tune.suggest.ax import AxSearch
        return AxSearch

    def _import_dragonfly_search():
        from ray.tune.suggest.dragonfly import DragonflySearch
        return DragonflySearch

    def _import_skopt_search():
        from ray.tune.suggest.skopt import SkOptSearch
        return SkOptSearch

    def _import_hyperopt_search():
        from ray.tune.suggest.hyperopt import HyperOptSearch
        return HyperOptSearch

    def _import_bayesopt_search():
        from ray.tune.suggest.bayesopt import BayesOptSearch
        return BayesOptSearch

    def _import_bohb_search():
        from ray.tune.suggest.bohb import TuneBOHB
        return TuneBOHB

    def _import_nevergrad_search():
        from ray.tune.suggest.nevergrad import NevergradSearch
        return NevergradSearch

    def _import_optuna_search():
        from ray.tune.suggest.optuna import OptunaSearch
        return OptunaSearch

    def _import_zoopt_search():
        from ray.tune.suggest.zoopt import ZOOptSearch
        return ZOOptSearch

    def _import_sigopt_search():
        from ray.tune.suggest.sigopt import SigOptSearch
        return SigOptSearch

    # "variant_generator" and "random" are aliases for the same class.
    SEARCH_ALG_IMPORT = {
        "variant_generator": _import_variant_generator,
        "random": _import_variant_generator,
        "ax": _import_ax_search,
        "dragonfly": _import_dragonfly_search,
        "skopt": _import_skopt_search,
        "hyperopt": _import_hyperopt_search,
        "bayesopt": _import_bayesopt_search,
        "bohb": _import_bohb_search,
        "nevergrad": _import_nevergrad_search,
        "optuna": _import_optuna_search,
        "zoopt": _import_zoopt_search,
        "sigopt": _import_sigopt_search,
    }
    search_alg = search_alg.lower()
    if search_alg not in SEARCH_ALG_IMPORT:
        raise ValueError(
            f"Search alg must be one of {list(SEARCH_ALG_IMPORT)}. "
            f"Got: {search_alg}")

    # Call the selected helper to import and obtain the searcher class.
    SearcherClass = SEARCH_ALG_IMPORT[search_alg]()

    # Drop kwargs the constructor does not declare.
    search_alg_args = get_function_args(SearcherClass)
    trimmed_kwargs = {k: v for k, v in kwargs.items() if k in search_alg_args}

    return SearcherClass(**trimmed_kwargs)