def test_split(
    args: List[str],
    max_batch_size: Optional[int],
    expected: List[List[List[str]]],
) -> None:
    """Check that ``BasicSweeper.split_arguments`` yields the expected batches.

    ``args`` are raw override strings; they are parsed and then split with the
    given ``max_batch_size``.
    """
    parser = OverridesParser.create()
    batches = BasicSweeper.split_arguments(
        parser.parse_overrides(args), max_batch_size=max_batch_size
    )
    # Turn every batch into a plain list so it compares equal to ``expected``.
    materialized = []
    for batch in batches:
        materialized.append(list(batch))
    assert materialized == expected
def sweep(self, arguments: List[str]) -> Any:
    """Split the sweep arguments into batches and launch them one by one.

    Before any job is launched, the sweep configuration is written to
    ``multirun.yaml`` inside the sweep directory.

    Args:
        arguments: raw override strings describing the sweep.

    Returns:
        A list with one entry per launched batch; each entry is the value
        returned by ``self.launcher.launch`` for that batch.
    """
    assert self.config is not None
    assert self.launcher is not None

    parser = OverridesParser.create(config_loader=self.config_loader)
    overrides = parser.parse_overrides(arguments)
    self.overrides = self.split_arguments(overrides, self.max_batch_size)
    returns: List[Sequence[JobReturn]] = []

    # Save sweep run config in top level sweep working directory
    sweep_dir = Path(self.config.hydra.sweep.dir)
    sweep_dir.mkdir(parents=True, exist_ok=True)
    OmegaConf.save(self.config, sweep_dir / "multirun.yaml")

    initial_job_idx = 0
    while not self.is_done():
        batch = self.get_job_batch()
        tic = time.perf_counter()
        # Validate that jobs can be safely composed. This catches composition
        # errors early. This can be a bit slow for large jobs and could
        # potentially be made optional through the config.
        self.validate_batch_is_legal(batch)
        elapsed = time.perf_counter() - tic
        # The f-string below is evaluated even when debug logging is off, so a
        # zero timer delta must not be allowed to raise ZeroDivisionError.
        rate = len(batch) / elapsed if elapsed > 0 else float("inf")
        log.debug(
            f"Validated configs of {len(batch)} jobs in {elapsed:0.2f} seconds, {rate:.2f} / second)"
        )
        results = self.launcher.launch(batch, initial_job_idx=initial_job_idx)
        # Advance the index so the next batch gets fresh job ids.
        initial_job_idx += len(batch)
        returns.append(results)
    return returns
def sweep(self, arguments: List[str]) -> Any:
    """Launch the cartesian product of all overrides, in chunks.

    Every override contributes a list of ``key=value`` strings (several for a
    sweep override, exactly one otherwise). The product of those lists is the
    full set of jobs, which is launched in chunks of at most
    ``self.max_batch_size`` jobs.

    Returns:
        A list holding the launcher's result for each chunk.
    """
    assert self.config is not None
    assert self.launcher is not None
    log.info(f"ExampleSweeper (foo={self.foo}, bar={self.bar}) sweeping")
    log.info(f"Sweep output dir : {self.config.hydra.sweep.dir}")

    # Persist the sweep configuration at the top of the sweep directory.
    output_dir = Path(self.config.hydra.sweep.dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    OmegaConf.save(self.config, output_dir / "multirun.yaml")

    def expand(override: Any) -> List[str]:
        """Return the ``key=value`` strings a single override stands for."""
        if not override.is_sweep_override():
            key = override.get_key_element()
            value = override.get_value_element_as_str()
            return [f"{key}={value}"]
        # Sweepers must only manipulate overrides for which
        # is_sweep_override() is true. The sweep syntax is shared by all
        # sweepers; ignoring this rule leads to hard-to-debug problems.
        # Grammar extensions should be proposed upstream via an issue, as
        # they can break compatibility for existing users.
        choices = override.sweep_string_iterator()
        key = override.get_key_element()
        return [f"{key}={choice}" for choice in choices]

    parsed = OverridesParser.create().parse_overrides(arguments)
    options_per_key = [expand(override) for override in parsed]
    all_jobs = list(itertools.product(*options_per_key))

    # A sweeper that launches several batches must pass the right
    # initial_job_idx for each one, so that every job gets a unique job id
    # (which may in turn be used for things like the working directory).
    chunk_size = self.max_batch_size
    if chunk_size is None or chunk_size == -1:
        chunk_size = len(all_jobs)
    chunked_jobs = [
        all_jobs[start:start + chunk_size]
        for start in range(0, len(all_jobs), chunk_size)
    ]

    returns = []
    next_job_idx = 0
    for job_chunk in chunked_jobs:
        self.validate_batch_is_legal(job_chunk)
        chunk_results = self.launcher.launch(job_chunk, initial_job_idx=next_job_idx)
        next_job_idx += len(job_chunk)
        returns.append(chunk_results)
    return returns
def test_create_nevergrad_parameter_from_override(
    input: Any,
    expected: Any,
) -> None:
    """Check the nevergrad parameter built from a single override string."""
    overrides = OverridesParser.create().parse_overrides([input])
    # Only one override string was supplied, so only the first result matters.
    first = overrides[0]
    actual = _impl.create_nevergrad_parameter_from_override(first)
    assert_ng_param_equals(actual, expected)
def sweep(self, arguments: List[str]) -> None:
    """Run an Optuna study whose search space comes from the overrides.

    Trials are asked from the study in batches of up to
    ``self.optuna_config.n_jobs``, launched through ``self.launcher`` and
    reported back as completed. The best trial is saved to
    ``optimization_results.yaml`` in the sweep directory and logged.
    """
    parsed = OverridesParser.create().parse_overrides(arguments)

    search_space = {}
    for override in parsed:
        distribution = create_optuna_distribution_from_override(override)
        search_space[override.get_key_element()] = distribution

    study = optuna.create_study(
        study_name=self.optuna_config.study_name,
        storage=self.optuna_config.storage,
        direction=self.optuna_config.direction,
    )

    batch_size = self.optuna_config.n_jobs
    remaining = self.optuna_config.n_trials
    while remaining > 0:
        # The last batch may be smaller than n_jobs.
        batch_size = min(remaining, batch_size)

        trials = [study._ask() for _ in range(batch_size)]
        for trial in trials:
            for param_name, distribution in search_space.items():
                trial._suggest(param_name, distribution)

        # One tuple of "name=value" strings per trial.
        overrides = [
            tuple(f"{name}={val}" for name, val in trial.params.items())
            for trial in trials
        ]

        returns = self.launcher.launch(overrides, initial_job_idx=trials[0].number)
        for trial, ret in zip(trials, returns):
            study._tell(trial, optuna.trial.TrialState.COMPLETE, ret.return_value)

        remaining -= batch_size

    best_trial = study.best_trial
    summary = {
        "name": "optuna",
        "best_params": best_trial.params,
        "best_value": best_trial.value,
    }
    OmegaConf.save(
        OmegaConf.create(summary),
        f"{self.config.hydra.sweep.dir}/optimization_results.yaml",
    )
    log.info(f"Best parameters: {best_trial.params}")
    log.info(f"Best value: {best_trial.value}")
    log.info(f"Storage: {self.optuna_config.storage}")
    log.info(f"Study name: {study.study_name}")
def test_apply_overrides_to_config(
    input_cfg: Any, overrides: List[str], expected: Any
) -> None:
    """Apply overrides to a struct-mode config and check the outcome.

    When ``expected`` is a dict, the resulting config must equal it; otherwise
    ``expected`` is used as a context manager wrapped around the call.
    """
    cfg = OmegaConf.create(input_cfg)
    OmegaConf.set_struct(cfg, True)
    parsed = OverridesParser.create().parse_overrides(overrides=overrides)

    if not isinstance(expected, dict):
        with expected:
            ConfigLoaderImpl._apply_overrides_to_config(overrides=parsed, cfg=cfg)
        return

    ConfigLoaderImpl._apply_overrides_to_config(overrides=parsed, cfg=cfg)
    assert cfg == expected
def compute_defaults_list(
    self,
    config_name: Optional[str],
    overrides: List[str],
    run_mode: RunMode,
) -> DefaultsList:
    """Build the defaults list for ``config_name`` with ``overrides`` applied.

    Missing configs are skipped when ``run_mode`` is ``RunMode.MULTIRUN``.
    """
    parser = OverridesParser.create()
    caching_repo = CachingConfigRepository(self.repository)
    parsed_overrides = parser.parse_overrides(overrides=overrides)
    is_multirun = run_mode == RunMode.MULTIRUN
    return create_defaults_list(
        repo=caching_repo,
        config_name=config_name,
        overrides_list=parsed_overrides,
        prepend_hydra=True,
        skip_missing=is_multirun,
    )
def test_delete_by_assigning_null_is_deprecated() -> None:
    """Assigning ``null`` to a defaults entry warns and removes the entry."""
    expected_message = (
        "\nRemoving from the defaults list by assigning 'null' "
        "is deprecated and will be removed in Hydra 1.1."
        "\nUse ~db"
    )

    initial_cfg = OmegaConf.create({"defaults": [{"db": "mysql"}]})
    defaults = ConfigLoaderImpl._parse_defaults(initial_cfg)

    null_override = OverridesParser.create().parse_override("db=null")

    # The message is escaped because it is matched as a regular expression.
    with pytest.warns(expected_warning=UserWarning, match=re.escape(expected_message)):
        ConfigLoaderImpl._apply_overrides_to_defaults(
            overrides=[null_override], defaults=defaults
        )

    assert defaults == []
def test_apply_overrides_to_defaults(
    input_defaults: List[str], overrides: List[str], expected: Any
) -> None:
    """Apply overrides to a parsed defaults list and check the outcome.

    When ``expected`` is a list, the modified defaults must equal the parsed
    form of that list; otherwise ``expected`` is used as a context manager
    wrapped around both parsing and applying the overrides.
    """
    defaults = ConfigLoaderImpl._parse_defaults(
        OmegaConf.create({"defaults": input_defaults})
    )
    parser = OverridesParser.create()

    if not isinstance(expected, list):
        with expected:
            parsed = parser.parse_overrides(overrides=overrides)
            ConfigLoaderImpl._apply_overrides_to_defaults(
                overrides=parsed, defaults=defaults
            )
        return

    parsed = parser.parse_overrides(overrides=overrides)
    expected_defaults = ConfigLoaderImpl._parse_defaults(
        OmegaConf.create({"defaults": expected})
    )
    ConfigLoaderImpl._apply_overrides_to_defaults(overrides=parsed, defaults=defaults)
    assert defaults == expected_defaults
def _test_defaults_tree_impl(
    config_name: Optional[str],
    input_overrides: List[str],
    expected: Any,
    prepend_hydra: bool = False,
    skip_missing: bool = False,
) -> Optional[DefaultsList]:
    """Build a defaults tree and compare it with ``expected``.

    When ``expected`` is ``None`` or a ``DefaultsTreeNode``, the built tree
    must equal it and a ``DefaultsList`` wrapping the tree is returned.
    Otherwise ``expected`` is used as a context manager wrapped around the
    tree construction and ``None`` is returned.
    """
    parser = OverridesParser.create()
    repo = create_repo()
    root = _create_root(config_name=config_name, with_hydra=prepend_hydra)
    parsed = parser.parse_overrides(overrides=input_overrides)
    overrides = Overrides(repo=repo, overrides_list=parsed)

    def build_tree() -> Any:
        """Create the tree and verify that every override/deletion was used."""
        tree = _create_defaults_tree(
            repo=repo,
            root=root,
            overrides=overrides,
            is_root_config=True,
            interpolated_subtree=False,
            skip_missing=skip_missing,
        )
        overrides.ensure_overrides_used()
        overrides.ensure_deletions_used()
        return tree

    expects_tree = expected is None or isinstance(expected, DefaultsTreeNode)
    if not expects_tree:
        with expected:
            build_tree()
        return None

    result = build_tree()
    assert result == expected
    return DefaultsList(
        defaults=[],
        defaults_tree=result,
        overrides=overrides,
        config_overrides=[],
    )
def parse_commandline_args(
    self, arguments: List[str]
) -> List[Dict[str, Union[ax_types.TParamValue, List[ax_types.TParamValue]]]]:
    """Parse the command line arguments and convert them into Ax parameters.

    Choice and range sweeps become choice parameters, interval sweeps become
    range parameters, and plain (non-sweep) overrides become fixed parameters.
    Non-sweep overrides of Hydra's own configuration produce no parameter.

    Args:
        arguments: raw override strings.

    Returns:
        One parameter dict per override that maps to an Ax parameter, in the
        order the overrides were given.

    Raises:
        ValueError: if a sweep override is not a choice, range or interval
            sweep.
    """
    parser = OverridesParser.create()
    parsed = parser.parse_overrides(arguments)
    parameters: List[Dict[str, Any]] = []
    for override in parsed:
        if override.is_sweep_override():
            if override.is_choice_sweep():
                param = create_choice_param_from_choice_override(override)
            elif override.is_range_sweep():
                param = create_choice_param_from_range_override(override)
            elif override.is_interval_sweep():
                param = create_range_param_using_interval_override(override)
            else:
                # Previously this path left ``param`` unbound or stale.
                raise ValueError(
                    f"Unsupported sweep override for the Ax sweeper: '{override.input_line}'"
                )
        elif not override.is_hydra_override():
            param = create_fixed_param_from_element_override(override)
        else:
            # Nothing to add for Hydra overrides. Appending here used to
            # duplicate the previous parameter or hit an unbound local.
            continue
        parameters.append(param)
    return parameters
def parse_overrides(
    overrides: List[str],
    run_mode: RunMode,
    from_shell: bool,
) -> List[Override]:
    """Parse ``overrides`` and return the ones that are not sweep overrides.

    Sweep overrides are validated against ``run_mode``: in MULTIRUN mode they
    are left for the sweeper (sweeping over Hydra's own config is rejected);
    in RUN mode they are always an error.

    Raises:
        ConfigCompositionException: for a sweep over Hydra's configuration in
            MULTIRUN mode, or for any sweep override in RUN mode.
    """
    parsed_overrides = OverridesParser.create().parse_overrides(overrides=overrides)
    config_overrides = []
    for override in parsed_overrides:
        if not override.is_sweep_override():
            config_overrides.append(override)
            continue

        if run_mode == RunMode.MULTIRUN:
            if override.is_hydra_override():
                raise ConfigCompositionException(
                    f"Sweeping over Hydra's configuration is not supported : '{override.input_line}'"
                )
            # Sweep overrides are not processed here in multirun mode;
            # the sweeper handles them directly.
        elif run_mode == RunMode.RUN:
            if override.value_type != ValueType.SIMPLE_CHOICE_SWEEP:
                raise ConfigCompositionException(
                    f"Sweep parameters '{override.input_line}' requires --multirun"
                )
            vals = "value1,value2"
            # From a shell the quotes themselves have to be escaped.
            example_override = f"key=\\'{vals}\\'" if from_shell else f"key='{vals}'"
            msg = dedent(
                f"""\
                Ambiguous value for argument '{override.input_line}'
                1. To use it as a list, use key=[value1,value2]
                2. To use it as string, quote the value: {example_override}
                3. To sweep over it, add --multirun to your command line"""
            )
            raise ConfigCompositionException(msg)
        else:
            assert False
    return config_overrides
def apply_overrides(cfg, overrides: List[str]):
    """
    Apply a list of override strings to cfg, modifying it in place.

    Args:
        cfg: an omegaconf config object
        overrides: list of strings in the format of "a=b" to override configs.
            See https://hydra.cc/docs/next/advanced/override_grammar/basic/
            for syntax.

    Returns:
        the cfg object
    """
    from hydra.core.override_parser.overrides_parser import OverridesParser

    def safe_update(cfg, key, value):
        # Walk the proper prefixes of the dotted key (shortest first) and make
        # sure that every prefix that exists is itself a config node.
        segments = key.split(".")
        for depth in range(1, len(segments)):
            prefix = ".".join(segments[:depth])
            node = OmegaConf.select(cfg, prefix, default=None)
            if node is None:
                # Nothing exists from here on, so nothing can be clobbered.
                break
            if OmegaConf.is_config(node):
                continue
            raise KeyError(
                f"Trying to update key {key}, but {prefix} "
                f"is not a config, but has type {type(node)}."
            )
        OmegaConf.update(cfg, key, value, merge=True)

    parsed = OverridesParser.create().parse_overrides(overrides)
    for override in parsed:
        key = override.key_or_group
        value = override.value()
        if override.is_delete():
            # TODO support this
            raise NotImplementedError("deletion is not yet a supported override")
        safe_update(cfg, key, value)
    return cfg
def apply_overrides(cfg, overrides: List[str]):
    """
    In-place override contents of cfg.

    Args:
        cfg: an omegaconf config object
        overrides: list of strings in the format of "a=b" to override configs.
            See https://hydra.cc/docs/next/advanced/override_grammar/basic/
            for syntax.

    Returns:
        the cfg object

    Raises:
        NotImplementedError: if any override is a deletion.
    """
    from hydra.core.override_parser.overrides_parser import OverridesParser

    parser = OverridesParser.create()
    parsed = parser.parse_overrides(overrides)
    for o in parsed:
        # TODO seems nice to support this
        # Raise explicitly instead of asserting: asserts are stripped under
        # ``python -O``, which would let the deletion through unchecked.
        if o.is_delete():
            raise NotImplementedError("deletion is not a supported override")
        OmegaConf.update(cfg, o.key_or_group, o.value(), merge=True)
    return cfg
def sweep(self, arguments: List[str]) -> None:
    """Run a (possibly multi-objective) Optuna study driven by the overrides.

    Overrides that map to an Optuna distribution extend the search space; the
    others are fixed parameters passed unchanged to every job. Trials are
    launched in batches of up to ``self.n_jobs``. A job whose return value
    cannot be converted to the expected float(s) marks its trial as failed
    and the error is re-raised. The best trial (single objective) or the
    Pareto front (several objectives) is saved to
    ``optimization_results.yaml`` in the sweep directory.
    """
    assert self.config is not None
    assert self.launcher is not None
    assert self.hydra_context is not None
    assert self.job_idx is not None

    parsed = OverridesParser.create().parse_overrides(arguments)

    search_space = dict(self.search_space)
    fixed_params = {}
    for override in parsed:
        value = create_optuna_distribution_from_override(override)
        target = search_space if isinstance(value, BaseDistribution) else fixed_params
        target[override.get_key_element()] = value

    # A parameter fixed on the command line must not be searched over.
    for fixed_name in fixed_params:
        search_space.pop(fixed_name, None)

    directions: List[str]
    if isinstance(self.direction, MutableSequence):
        directions = [
            d.name if isinstance(d, Direction) else d for d in self.direction
        ]
    elif isinstance(self.direction, str):
        directions = [self.direction]
    else:
        directions = [self.direction.name]

    study = optuna.create_study(
        study_name=self.study_name,
        storage=self.storage,
        sampler=self.sampler,
        directions=directions,
        load_if_exists=True,
    )
    log.info(f"Study name: {study.study_name}")
    log.info(f"Storage: {self.storage}")
    log.info(f"Sampler: {type(self.sampler).__name__}")
    log.info(f"Directions: {directions}")

    batch_size = self.n_jobs
    remaining = self.n_trials
    while remaining > 0:
        # The last batch may be smaller than n_jobs.
        batch_size = min(remaining, batch_size)

        trials = [study._ask() for _ in range(batch_size)]
        overrides = []
        for trial in trials:
            for param_name, distribution in search_space.items():
                trial._suggest(param_name, distribution)

            # Fixed parameters take precedence over suggested ones.
            params = dict(trial.params)
            params.update(fixed_params)
            overrides.append(tuple(f"{name}={val}" for name, val in params.items()))

        returns = self.launcher.launch(overrides, initial_job_idx=self.job_idx)
        self.job_idx += len(returns)

        for trial, ret in zip(trials, returns):
            values: Optional[List[float]] = None
            state: optuna.trial.TrialState = optuna.trial.TrialState.COMPLETE
            try:
                if len(directions) == 1:
                    # Single objective: the job returns one float-castable value.
                    try:
                        values = [float(ret.return_value)]
                    except (ValueError, TypeError):
                        raise ValueError(
                            f"Return value must be float-castable. Got '{ret.return_value}'."
                        ).with_traceback(sys.exc_info()[2])
                else:
                    # Several objectives: one float-castable value per direction.
                    try:
                        values = [float(v) for v in ret.return_value]
                    except (ValueError, TypeError):
                        raise ValueError(
                            "Return value must be a list or tuple of float-castable values."
                            f" Got '{ret.return_value}'."
                        ).with_traceback(sys.exc_info()[2])
                    if len(values) != len(directions):
                        raise ValueError(
                            "The number of the values and the number of the objectives are"
                            f" mismatched. Expect {len(directions)}, but actually {len(values)}."
                        )
                study._tell(trial, state, values)
            except Exception as e:
                # Record the failure in the study before propagating the error.
                state = optuna.trial.TrialState.FAIL
                study._tell(trial, state, values)
                raise e

        remaining -= batch_size

    results_to_serialize: Dict[str, Any]
    if len(directions) >= 2:
        best_trials = study.best_trials
        pareto_front = [{"params": t.params, "values": t.values} for t in best_trials]
        results_to_serialize = {
            "name": "optuna",
            "solutions": pareto_front,
        }
        log.info(f"Number of Pareto solutions: {len(best_trials)}")
        for t in best_trials:
            log.info(f" Values: {t.values}, Params: {t.params}")
    else:
        best_trial = study.best_trial
        results_to_serialize = {
            "name": "optuna",
            "best_params": best_trial.params,
            "best_value": best_trial.value,
        }
        log.info(f"Best parameters: {best_trial.params}")
        log.info(f"Best value: {best_trial.value}")

    OmegaConf.save(
        OmegaConf.create(results_to_serialize),
        f"{self.config.hydra.sweep.dir}/optimization_results.yaml",
    )
def test_create_optuna_distribution_from_override(input: Any, expected: Any) -> None:
    """Check the Optuna distribution built from a single override string."""
    overrides = OverridesParser.create().parse_overrides([input])
    # Only one override string was supplied, so only the first result matters.
    first = overrides[0]
    distribution = _impl.create_optuna_distribution_from_override(first)
    check_distribution(expected, distribution)
def _load_configuration_impl(
    self,
    config_name: Optional[str],
    overrides: List[str],
    run_mode: RunMode,
    from_shell: bool = True,
) -> DictConfig:
    """Compose the configuration for ``config_name`` with ``overrides``.

    The overrides are parsed once and used to process the config search path,
    validate sweep overrides, build the defaults list and finally patch the
    composed config. Runtime and job information is then recorded under
    ``cfg.hydra``.

    Returns:
        The composed config, with struct mode enabled on the root.
    """
    from hydra import __version__

    self.ensure_main_config_source_available()
    caching_repo = CachingConfigRepository(self.repository)

    parsed_overrides = OverridesParser.create().parse_overrides(overrides=overrides)

    self._process_config_searchpath(config_name, parsed_overrides, caching_repo)
    self.validate_sweep_overrides_legal(
        overrides=parsed_overrides, run_mode=run_mode, from_shell=from_shell
    )

    is_multirun = run_mode == RunMode.MULTIRUN
    defaults_list = create_defaults_list(
        repo=caching_repo,
        config_name=config_name,
        overrides_list=parsed_overrides,
        prepend_hydra=True,
        skip_missing=is_multirun,
    )
    config_overrides = defaults_list.config_overrides

    cfg = self._compose_config_from_defaults_list(
        defaults=defaults_list.defaults, repo=caching_repo
    )

    # Put the config root in struct mode. This also closes every dictionary
    # (including those annotated as Dict[K, V]); new fields then need '+'.
    OmegaConf.set_struct(cfg, True)

    # The hydra node stays writable even when the root config is read-only.
    OmegaConf.set_readonly(cfg.hydra, False)

    # Command line overrides are applied after struct mode is enabled.
    ConfigLoaderImpl._apply_overrides_to_config(config_overrides, cfg)

    # Record every override under hydra.overrides; non-Hydra ones are also
    # kept for the override dirname computed below.
    app_overrides = []
    for override in parsed_overrides:
        if override.is_hydra_override():
            cfg.hydra.overrides.hydra.append(override.input_line)
            continue
        cfg.hydra.overrides.task.append(override.input_line)
        app_overrides.append(override)

    with open_dict(cfg.hydra):
        cfg.hydra.runtime.choices.update(defaults_list.overrides.known_choices)

        # Copy the requested environment variables into env_set.
        for key in cfg.hydra.job.env_copy:
            cfg.hydra.job.env_set[key] = os.environ[key]

        cfg.hydra.runtime.version = __version__
        cfg.hydra.runtime.cwd = os.getcwd()
        cfg.hydra.runtime.config_sources = [
            ConfigSourceInfo(
                path=source.path,
                schema=source.scheme(),
                provider=source.provider,
            )
            for source in caching_repo.get_sources()
        ]

        if "name" not in cfg.hydra.job:
            cfg.hydra.job.name = JobRuntime().get("name")

        cfg.hydra.job.override_dirname = get_overrides_dirname(
            overrides=app_overrides,
            kv_sep=cfg.hydra.job.config.override_dirname.kv_sep,
            item_sep=cfg.hydra.job.config.override_dirname.item_sep,
            exclude_keys=cfg.hydra.job.config.override_dirname.exclude_keys,
        )
        cfg.hydra.job.config_name = config_name

    return cfg
def _load_configuration_impl(
    self,
    config_name: Optional[str],
    overrides: List[str],
    run_mode: RunMode,
    from_shell: bool = True,
) -> DictConfig:
    """Compose the configuration for ``config_name`` with ``overrides``.

    Sweep overrides are validated, the defaults list is created from a fresh
    parse of the overrides, and the composed config is patched with the
    command line overrides. Runtime and job information is then recorded
    under ``cfg.hydra``.

    Returns:
        The composed config.
    """
    self.ensure_main_config_source_available()
    caching_repo = CachingConfigRepository(self.repository)

    parser = OverridesParser.create()
    parsed_overrides = parser.parse_overrides(overrides=overrides)

    self.validate_sweep_overrides_legal(
        overrides=parsed_overrides, run_mode=run_mode, from_shell=from_shell
    )

    # The defaults list receives its own, separately parsed override objects.
    defaults_list = create_defaults_list(
        repo=caching_repo,
        config_name=config_name,
        overrides_list=parser.parse_overrides(overrides=overrides),
        prepend_hydra=True,
        skip_missing=run_mode == RunMode.MULTIRUN,
    )
    config_overrides = defaults_list.config_overrides

    cfg = self._compose_config_from_defaults_list(
        defaults=defaults_list.defaults, repo=caching_repo
    )

    # Struct mode is set on the user config only when the root is untyped.
    root_is_plain_dict = OmegaConf.get_type(cfg) is dict
    if root_is_plain_dict:
        OmegaConf.set_struct(cfg, True)

    # Struct mode is switched off on the hydra node, so fields can be added
    # to nested dicts like hydra.job.env_set without using '+'.
    OmegaConf.set_struct(cfg.hydra, False)

    # The hydra node must be writable even if the user made the primary
    # node read-only.
    OmegaConf.set_readonly(cfg.hydra, False)

    # Command line overrides are applied after the flags above are set.
    ConfigLoaderImpl._apply_overrides_to_config(config_overrides, cfg)

    app_overrides = []
    for override in parsed_overrides:
        if override.is_hydra_override():
            cfg.hydra.overrides.hydra.append(override.input_line)
            continue
        cfg.hydra.overrides.task.append(override.input_line)
        app_overrides.append(override)

    cfg.hydra.choices.update(defaults_list.overrides.known_choices)

    with open_dict(cfg.hydra):
        from hydra import __version__

        cfg.hydra.runtime.version = __version__
        cfg.hydra.runtime.cwd = os.getcwd()

        if "name" not in cfg.hydra.job:
            cfg.hydra.job.name = JobRuntime().get("name")

        cfg.hydra.job.override_dirname = get_overrides_dirname(
            overrides=app_overrides,
            kv_sep=cfg.hydra.job.config.override_dirname.kv_sep,
            item_sep=cfg.hydra.job.config.override_dirname.item_sep,
            exclude_keys=cfg.hydra.job.config.override_dirname.exclude_keys,
        )
        cfg.hydra.job.config_name = config_name

        # Copy the requested environment variables into env_set.
        for key in cfg.hydra.job.env_copy:
            cfg.hydra.job.env_set[key] = os.environ[key]

    return cfg
def sweep(self, arguments: List[str]) -> None:
    """Optimize with a nevergrad optimizer, launching candidates in batches.

    Command line overrides replace or extend the configured parametrization.
    Candidates are asked from the optimizer in batches of up to
    ``num_workers`` until the budget is used up; each job's return value
    (sign-flipped when maximizing) is told back as the loss. The best
    evaluated candidate is saved to ``optimization_results.yaml`` in the
    sweep directory and the optimizer's recommendation is logged.
    """
    assert self.config is not None
    assert self.launcher is not None
    assert self.job_idx is not None

    # Losses are minimized, so maximization flips the sign of the result.
    if self.opt_config.maximize:
        direction = -1
    else:
        direction = 1
    if self.opt_config.maximize:
        name = "maximization"
    else:
        name = "minimization"

    # Command line overrides take precedence over the configured parameters.
    params = dict(self.parametrization)
    parsed = OverridesParser.create().parse_overrides(arguments)
    for override in parsed:
        ng_param = create_nevergrad_parameter_from_override(override)
        params[override.get_key_element()] = ng_param

    parametrization = ng.p.Dict(**params)
    parametrization.function.deterministic = not self.opt_config.noisy
    parametrization.random_state.seed(self.opt_config.seed)

    # Log the setup and build the optimizer.
    opt = self.opt_config.optimizer
    remaining_budget = self.opt_config.budget
    nw = self.opt_config.num_workers
    log.info(
        f"NevergradSweeper(optimizer={opt}, budget={remaining_budget}, "
        f"num_workers={nw}) {name}"
    )
    log.info(f"with parametrization {parametrization}")
    log.info(f"Sweep output dir: {self.config.hydra.sweep.dir}")
    optimizer = ng.optimizers.registry[opt](parametrization, remaining_budget, nw)

    all_returns: List[Any] = []
    best_loss = float("inf")
    best_candidate = parametrization
    while remaining_budget > 0:
        batch = min(nw, remaining_budget)
        remaining_budget -= batch

        candidates = [optimizer.ask() for _ in range(batch)]
        overrides = [
            tuple(f"{key}={val}" for key, val in candidate.value.items())
            for candidate in candidates
        ]
        self.validate_batch_is_legal(overrides)

        returns = self.launcher.launch(overrides, initial_job_idx=self.job_idx)
        self.job_idx += len(returns)

        # All jobs of the batch have finished at this point; a steady-state
        # scheme (launch a new job whenever one finishes) would avoid waiting.
        for candidate, ret in zip(candidates, returns):
            loss = direction * ret.return_value
            optimizer.tell(candidate, loss)
            if loss < best_loss:
                best_loss = loss
                best_candidate = candidate
        all_returns.extend(returns)

    recom = optimizer.provide_recommendation()
    results_to_serialize = {
        "name": "nevergrad",
        "best_evaluated_params": best_candidate.value,
        # Undo the sign flip so the reported value is the job's own result.
        "best_evaluated_result": direction * best_loss,
    }
    OmegaConf.save(
        OmegaConf.create(results_to_serialize),
        f"{self.config.hydra.sweep.dir}/optimization_results.yaml",
    )
    log.info(
        "Best parameters: %s", " ".join(f"{x}={y}" for x, y in recom.value.items())
    )
def _load_configuration_impl(
    self,
    config_name: Optional[str],
    overrides: List[str],
    run_mode: RunMode,
    strict: Optional[bool] = None,
    from_shell: bool = True,
) -> DictConfig:
    """Compose the configuration for ``config_name`` with ``overrides``.

    ``strict`` falls back to ``self.default_strict`` when it is ``None`` and
    is used as the struct flag of the composed config. The defaults list is
    created from a fresh parse of the overrides. Runtime information, the
    composition trace and job information are recorded under ``cfg.hydra``.

    Returns:
        The composed config.
    """
    self.ensure_main_config_source_available()
    caching_repo = CachingConfigRepository(self.repository)

    if strict is None:
        strict = self.default_strict

    parser = OverridesParser.create()
    parsed_overrides = parser.parse_overrides(overrides=overrides)

    self.validate_sweep_overrides_legal(
        overrides=parsed_overrides, run_mode=run_mode, from_shell=from_shell
    )

    # The defaults list receives its own, separately parsed override objects.
    defaults_list = create_defaults_list(
        repo=caching_repo,
        config_name=config_name,
        overrides_list=parser.parse_overrides(overrides=overrides),
        prepend_hydra=True,
        skip_missing=run_mode == RunMode.MULTIRUN,
    )
    config_overrides = defaults_list.config_overrides

    composed = self._compose_config_from_defaults_list(
        defaults=defaults_list.defaults, repo=caching_repo
    )
    cfg, composition_trace = composed

    OmegaConf.set_struct(cfg, strict)
    OmegaConf.set_readonly(cfg.hydra, False)

    # Command line overrides are applied after the strict flag is set.
    ConfigLoaderImpl._apply_overrides_to_config(config_overrides, cfg)

    app_overrides = []
    for override in parsed_overrides:
        if override.is_hydra_override():
            cfg.hydra.overrides.hydra.append(override.input_line)
            continue
        cfg.hydra.overrides.task.append(override.input_line)
        app_overrides.append(override)

    # TODO: should this open_dict be required given that choices is a Dict?
    with open_dict(cfg.hydra.choices):
        cfg.hydra.choices.update(defaults_list.overrides.known_choices)

    with open_dict(cfg.hydra):
        from hydra import __version__

        cfg.hydra.runtime.version = __version__
        cfg.hydra.runtime.cwd = os.getcwd()
        cfg.hydra.composition_trace = composition_trace

        if "name" not in cfg.hydra.job:
            cfg.hydra.job.name = JobRuntime().get("name")

        cfg.hydra.job.override_dirname = get_overrides_dirname(
            overrides=app_overrides,
            kv_sep=cfg.hydra.job.config.override_dirname.kv_sep,
            item_sep=cfg.hydra.job.config.override_dirname.item_sep,
            exclude_keys=cfg.hydra.job.config.override_dirname.exclude_keys,
        )
        cfg.hydra.job.config_name = config_name

        # Copy the requested environment variables into env_set.
        for key in cfg.hydra.job.env_copy:
            cfg.hydra.job.env_set[key] = os.environ[key]

    return cfg
def _load_configuration_impl(
    self,
    config_name: Optional[str],
    overrides: List[str],
    run_mode: RunMode,
    strict: Optional[bool] = None,
    from_shell: bool = True,
) -> DictConfig:
    """Compose the configuration for ``config_name`` with ``overrides``.

    A missing primary config is reported through ``self._missing_config_error``.
    The input defaults consist of ``hydra_config``, the primary config (when a
    name is given) and the defaults derived from config group overrides; the
    expanded list is composed into the config. Runtime information, the
    composition trace and job information are recorded under ``cfg.hydra``.

    Returns:
        The composed config.
    """
    self.ensure_main_config_source_available()
    caching_repo = CachingConfigRepository(self.repository)

    primary_missing = config_name is not None and not caching_repo.config_exists(
        config_name
    )
    if primary_missing:
        self._missing_config_error(
            config_name=config_name,
            msg=f"Cannot find primary config : {config_name}, check that it's in your config search path",
            with_search_path=True,
        )

    if strict is None:
        strict = self.default_strict

    # All overrides, used further down to record what was passed in.
    parsed_overrides = OverridesParser.create().parse_overrides(overrides=overrides)

    # Overrides that survive sweep validation, split by kind.
    validated = ConfigLoaderImpl.parse_overrides(
        overrides=overrides, run_mode=run_mode, from_shell=from_shell
    )
    split_res = self.split_by_override_type(validated)
    config_group_overrides = split_res.config_group_overrides
    config_overrides = split_res.config_overrides

    input_defaults = [DefaultElement(config_name="hydra_config")]
    if config_name is not None:
        input_defaults.append(DefaultElement(config_name=config_name, primary=True))
    input_defaults.extend(convert_overrides_to_defaults(config_group_overrides))

    defaults = expand_defaults_list(
        defaults=input_defaults,
        skip_missing=run_mode == RunMode.MULTIRUN,
        repo=caching_repo,
    )

    composed = self._compose_config_from_defaults_list(
        defaults=defaults, repo=caching_repo
    )
    cfg, composition_trace = composed

    OmegaConf.set_struct(cfg, strict)
    OmegaConf.set_readonly(cfg.hydra, False)

    # Command line overrides are applied after the strict flag is set.
    ConfigLoaderImpl._apply_overrides_to_config(config_overrides, cfg)

    app_overrides = []
    for override in parsed_overrides:
        if override.is_hydra_override():
            cfg.hydra.overrides.hydra.append(override.input_line)
            continue
        cfg.hydra.overrides.task.append(override.input_line)
        app_overrides.append(override)

    with open_dict(cfg.hydra):
        from hydra import __version__

        cfg.hydra.runtime.version = __version__
        cfg.hydra.runtime.cwd = os.getcwd()
        cfg.hydra.composition_trace = composition_trace

        if "name" not in cfg.hydra.job:
            cfg.hydra.job.name = JobRuntime().get("name")

        cfg.hydra.job.override_dirname = get_overrides_dirname(
            overrides=app_overrides,
            kv_sep=cfg.hydra.job.config.override_dirname.kv_sep,
            item_sep=cfg.hydra.job.config.override_dirname.item_sep,
            exclude_keys=cfg.hydra.job.config.override_dirname.exclude_keys,
        )
        cfg.hydra.job.config_name = config_name

        # Copy the requested environment variables into env_set.
        for key in cfg.hydra.job.env_copy:
            cfg.hydra.job.env_set[key] = os.environ[key]

    return cfg
def sweep(self, arguments: List[str]) -> None:
    """Run a single-objective Optuna study driven by the overrides.

    Overrides that map to an Optuna distribution extend the search space; the
    others are fixed parameters passed unchanged to every job. The sampler is
    chosen by name from a fixed table. Trials are launched in batches of up
    to ``n_jobs`` and reported back as completed. The best trial is saved to
    ``optimization_results.yaml`` in the sweep directory and logged.

    Raises:
        NotImplementedError: if the configured sampler name is not in the
            table of supported samplers.
    """
    assert self.config is not None
    assert self.launcher is not None
    assert self.job_idx is not None

    parsed = OverridesParser.create().parse_overrides(arguments)

    search_space = dict(self.search_space)
    fixed_params = {}
    for override in parsed:
        value = create_optuna_distribution_from_override(override)
        target = search_space if isinstance(value, BaseDistribution) else fixed_params
        target[override.get_key_element()] = value

    # A parameter fixed on the command line must not be searched over.
    for fixed_name in fixed_params:
        search_space.pop(fixed_name, None)

    samplers = {
        "tpe": "optuna.samplers.TPESampler",
        "random": "optuna.samplers.RandomSampler",
        "cmaes": "optuna.samplers.CmaEsSampler",
    }
    if self.optuna_config.sampler.name not in samplers:
        raise NotImplementedError(
            f"{self.optuna_config.sampler} is not supported by Optuna sweeper."
        )

    sampler_class = get_class(samplers[self.optuna_config.sampler.name])
    sampler = sampler_class(seed=self.optuna_config.seed)

    # TODO (toshihikoyanase): Remove type-ignore when optuna==2.4.0 is released.
    study = optuna.create_study(  # type: ignore
        study_name=self.optuna_config.study_name,
        storage=self.optuna_config.storage,
        sampler=sampler,
        direction=self.optuna_config.direction.name,
        load_if_exists=True,
    )
    log.info(f"Study name: {study.study_name}")
    log.info(f"Storage: {self.optuna_config.storage}")
    log.info(f"Sampler: {self.optuna_config.sampler.name}")
    log.info(f"Direction: {self.optuna_config.direction.name}")

    batch_size = self.optuna_config.n_jobs
    remaining = self.optuna_config.n_trials
    while remaining > 0:
        # The last batch may be smaller than n_jobs.
        batch_size = min(remaining, batch_size)

        trials = [study._ask() for _ in range(batch_size)]
        overrides = []
        for trial in trials:
            for param_name, distribution in search_space.items():
                trial._suggest(param_name, distribution)

            # Fixed parameters take precedence over suggested ones.
            params = dict(trial.params)
            params.update(fixed_params)
            overrides.append(tuple(f"{name}={val}" for name, val in params.items()))

        returns = self.launcher.launch(overrides, initial_job_idx=self.job_idx)
        self.job_idx += len(returns)

        for trial, ret in zip(trials, returns):
            # TODO (toshihikoyanase): Remove type-ignore when optuna==2.4.0 is released.
            study._tell(trial, optuna.trial.TrialState.COMPLETE, ret.return_value)  # type: ignore

        remaining -= batch_size

    best_trial = study.best_trial
    summary = {
        "name": "optuna",
        "best_params": best_trial.params,
        "best_value": best_trial.value,
    }
    OmegaConf.save(
        OmegaConf.create(summary),
        f"{self.config.hydra.sweep.dir}/optimization_results.yaml",
    )
    log.info(f"Best parameters: {best_trial.params}")
    log.info(f"Best value: {best_trial.value}")
def _load_configuration(
    self,
    config_name: Optional[str],
    overrides: List[str],
    run_mode: RunMode,
    strict: Optional[bool] = None,
    from_shell: bool = True,
) -> DictConfig:
    """Load the hydra and job configs, merge their defaults and apply overrides.

    Sweep overrides are validated against ``run_mode`` (they are only allowed
    in MULTIRUN mode and never for Hydra's own configuration). The remaining
    overrides are split into config group overrides, applied to the defaults
    list, and config value overrides, applied to the merged config.

    Returns:
        The merged config, with job information recorded under
        ``cfg.hydra.job``.

    Raises:
        ConfigCompositionException: for an illegal sweep override, or when a
            mandatory default is left unspecified.
    """
    primary_missing = config_name is not None and not self.repository.config_exists(
        config_name
    )
    if primary_missing:
        self.missing_config_error(
            config_name=config_name,
            msg=f"Cannot find primary config : {config_name}, check that it's in your config search path",
            with_search_path=True,
        )

    if strict is None:
        strict = self.default_strict

    parsed_overrides = OverridesParser.create().parse_overrides(overrides=overrides)

    config_overrides = []
    sweep_overrides = []
    for override in parsed_overrides:
        if not override.is_sweep_override():
            config_overrides.append(override)
            continue

        if run_mode == RunMode.MULTIRUN:
            if override.is_hydra_override():
                raise ConfigCompositionException(
                    f"Sweeping over Hydra's configuration is not supported : '{override.input_line}'"
                )
            sweep_overrides.append(override)
        elif run_mode == RunMode.RUN:
            if override.value_type != ValueType.SIMPLE_CHOICE_SWEEP:
                raise ConfigCompositionException(
                    f"Sweep parameters '{override.input_line}' requires --multirun"
                )
            vals = "value1,value2"
            # From a shell the quotes themselves have to be escaped.
            example_override = f"key=\\'{vals}\\'" if from_shell else f"key='{vals}'"
            msg = f"""Ambiguous value for argument '{override.input_line}'
1. To use it as a list, use key=[value1,value2]
2. To use it as string, quote the value: {example_override}
3. To sweep over it, add --multirun to your command line"""
            raise ConfigCompositionException(msg)
        else:
            assert False

    config_group_overrides, config_overrides = self.split_by_override_type(
        config_overrides
    )

    # Hydra's own config first, then the job config.
    hydra_cfg, _load_trace = self._load_primary_config(cfg_filename="hydra_config")
    job_cfg, job_cfg_load_trace = self._load_primary_config(
        cfg_filename=config_name, record_load=False
    )

    job_defaults = self._parse_defaults(job_cfg)
    defaults = self._parse_defaults(hydra_cfg)

    job_cfg_type = OmegaConf.get_type(job_cfg)
    job_cfg_is_typed = job_cfg_type is not None and not issubclass(job_cfg_type, dict)
    if job_cfg_is_typed:
        hydra_cfg._promote(job_cfg_type)

        # Otherwise the config would retain the readonly flag during the
        # regular merge later on.
        _recursive_unset_readonly(hydra_cfg)
        # This breaks encapsulation a bit; it could potentially be
        # implemented in OmegaConf.
        hydra_cfg._metadata.ref_type = job_cfg._metadata.ref_type

        OmegaConf.set_readonly(hydra_cfg.hydra, False)

        # Remove 'defaults' if the promotion re-introduced it.
        if "defaults" in hydra_cfg:
            with open_dict(hydra_cfg):
                del hydra_cfg["defaults"]

    if config_name is not None:
        defaults.append(DefaultElement(config_group=None, config_name="__SELF__"))
    split_at = len(defaults)

    self._combine_default_lists(defaults, job_defaults)
    ConfigLoaderImpl._apply_overrides_to_defaults(config_group_overrides, defaults)

    # Load the defaults and merge them into the config.
    try:
        cfg = self._merge_defaults_into_config(
            hydra_cfg,
            job_cfg,
            job_cfg_load_trace,
            defaults,
            split_at,
            run_mode=run_mode,
        )
    except UnspecifiedMandatoryDefault as e:
        options = self.get_group_options(e.config_group)
        opt_list = "\n".join("\t" + option for option in options)
        msg = (
            f"You must specify '{e.config_group}', e.g, {e.config_group}=<OPTION>"
            f"\nAvailable options:"
            f"\n{opt_list}"
        )
        raise ConfigCompositionException(msg) from e

    OmegaConf.set_struct(cfg, strict)

    # Command line overrides are applied after the strict flag is set.
    ConfigLoaderImpl._apply_overrides_to_config(config_overrides, cfg)

    app_overrides = []
    for override in parsed_overrides:
        if override.is_hydra_override():
            cfg.hydra.overrides.hydra.append(override.input_line)
            continue
        cfg.hydra.overrides.task.append(override.input_line)
        app_overrides.append(override)

    with open_dict(cfg.hydra.job):
        if "name" not in cfg.hydra.job:
            cfg.hydra.job.name = JobRuntime().get("name")

        cfg.hydra.job.override_dirname = get_overrides_dirname(
            overrides=app_overrides,
            kv_sep=cfg.hydra.job.config.override_dirname.kv_sep,
            item_sep=cfg.hydra.job.config.override_dirname.item_sep,
            exclude_keys=cfg.hydra.job.config.override_dirname.exclude_keys,
        )
        cfg.hydra.job.config_name = config_name

        # Copy the requested environment variables into env_set.
        for key in cfg.hydra.job.env_copy:
            cfg.hydra.job.env_set[key] = os.environ[key]

    return cfg