def _apply_overrides_to_config(overrides: List[str], cfg: DictConfig) -> None:
    """
    Applies a list of override strings to cfg, modifying cfg in place.

    Every entry is parsed with ConfigLoaderImpl._parse_config_override and then
    handled as one of three operations:
      - delete: removes the key. Fails if nothing is found at the key, or if
        the override carries a value that differs from the stored one.
      - add: inserts the key. Fails if something is already stored there.
      - plain override: updates the key. Fails if the key has no match in cfg.

    :param overrides: override strings, processed in the given order
    :param cfg: config object to modify
    :raises HydraException: if an override cannot be applied, or if an
        OmegaConfBaseException is raised while applying it
    """
    # NOTE(review): whether yaml.load is safe here depends on the loader
    # returned by _utils.get_yaml_loader() - confirm it is a safe loader.
    yaml_loader = _utils.get_yaml_loader()

    for line in overrides:
        override = ConfigLoaderImpl._parse_config_override(line)
        try:
            # The override value is parsed as YAML; an absent value stays None.
            raw_value = override.value
            if raw_value is None:
                parsed = None
            else:
                parsed = yaml.load(raw_value, Loader=yaml_loader)

            if override.is_delete():
                val = OmegaConf.select(cfg, override.key, throw_on_missing=False)
                if val is None:
                    raise HydraException(
                        f"Could not delete from config. '{override.key}' does not exist."
                    )
                # A value on a delete override acts as a guard: it has to
                # match what is currently stored.
                if parsed is not None and parsed != val:
                    raise HydraException(
                        f"Could not delete from config."
                        f" The value of '{override.key}' is {val} and not {override.value}."
                    )

                # Split "a.b.c" into the parent path "a.b" and the leaf "c";
                # a key without a dot lives directly on cfg.
                parent_path, dot, leaf = override.key.rpartition(".")
                with open_dict(cfg):
                    if dot:
                        parent = OmegaConf.select(cfg, parent_path)
                        del parent[leaf]
                    else:
                        del cfg[leaf]
                continue

            if override.is_add():
                current = OmegaConf.select(
                    cfg, override.key, throw_on_missing=False
                )
                if current is not None:
                    raise HydraException(
                        f"Could not append to config. An item is already at '{override.key}'."
                    )
                with open_dict(cfg):
                    OmegaConf.update(cfg, override.key, parsed)
                continue

            # Plain override: cfg is not opened with open_dict on this path.
            try:
                OmegaConf.update(cfg, override.key, parsed)
            except (ConfigAttributeError, ConfigKeyError) as ex:
                raise HydraException(
                    f"Could not override '{override.key}'. No match in config."
                    f"\nTo append to your config use +{line}"
                ) from ex
        except OmegaConfBaseException as ex:
            raise HydraException(f"Error merging override {line}") from ex
class BasicSweeper(Sweeper):
    """
    Basic sweeper.

    Expands the given overrides into the cartesian product of per-job override
    lists, optionally splits the jobs into batches of at most max_batch_size,
    and hands each batch to the launcher created in setup().
    """

    def __init__(self, max_batch_size: Optional[int]) -> None:
        """
        Instantiates the sweeper.

        :param max_batch_size: maximum number of jobs passed to the launcher in
            a single batch; None puts all jobs into one batch.
        """
        super().__init__()
        # Batches of jobs; each job is a list of "key=value" override strings.
        self.overrides: Optional[Sequence[Sequence[Sequence[str]]]] = None
        # Index of the next batch that get_job_batch() will return.
        self.batch_index = 0
        self.max_batch_size = max_batch_size
        # Both are populated by setup(). They are initialised here so that
        # calling sweep() before setup() trips its "is not None" assertions
        # instead of failing with an AttributeError.
        self.config: Optional[DictConfig] = None
        self.launcher: Any = None

    def setup(
        self,
        config: DictConfig,
        config_loader: ConfigLoader,
        task_function: TaskFunction,
    ) -> None:
        """
        Stores the config and instantiates the launcher via the plugin registry.

        :param config: config to sweep over; sweep() reads hydra.sweep.dir from it
        :param config_loader: forwarded to the launcher instantiation
        :param task_function: forwarded to the launcher instantiation
        """
        from hydra.core.plugins import Plugins

        self.config = config
        self.launcher = Plugins.instance().instantiate_launcher(
            config=config, config_loader=config_loader, task_function=task_function
        )
        # NOTE(review): the loader returned here was bound to a local that was
        # never read. The call itself is kept only in case get_yaml_loader()
        # has side effects something relies on - confirm and delete.
        get_yaml_loader()

    @staticmethod
    def split_overrides_to_chunks(
        lst: List[List[str]], n: Optional[int]
    ) -> Iterable[List[List[str]]]:
        """
        Yields consecutive slices of lst holding at most n items each.

        :param lst: list of jobs, each job being a list of override strings
        :param n: chunk size; None or -1 means one chunk containing all of lst
        :return: generator over the chunks; yields nothing when lst is empty
        """
        if n is None or n == -1:
            # "No limit" means a single chunk. max() keeps the step positive
            # for an empty list, which then simply yields nothing instead of
            # failing the assertion below with n == 0.
            n = max(len(lst), 1)
        assert n > 0
        for start in range(0, len(lst), n):
            yield lst[start : start + n]

    @staticmethod
    def split_arguments(
        arguments: List[str], max_batch_size: Optional[int]
    ) -> List[List[List[str]]]:
        """
        Expands overrides into batches of jobs.

        A sweep override contributes one "key=value" string per choice, any
        other override contributes exactly one. The jobs are the cartesian
        product of those contributions.

        :param arguments: override strings to parse
        :param max_batch_size: maximum number of jobs per batch; None returns
            a single batch with all jobs
        :return: list of batches; each batch is a list of jobs and each job is
            a list of "key=value" strings
        """
        parser = OverridesParser()
        parsed = parser.parse_overrides(arguments)

        lists = []
        for override in parsed:
            if override.is_sweep_override():
                sweep_choices = override.choices_as_strings()
                assert isinstance(sweep_choices, list)
                key = override.get_key_element()
                lists.append([f"{key}={val}" for val in sweep_choices])
            else:
                key = override.get_key_element()
                value = override.get_value_element()
                lists.append([f"{key}={value}"])

        all_batches = [list(x) for x in itertools.product(*lists)]
        assert max_batch_size is None or max_batch_size > 0
        if max_batch_size is None:
            return [all_batches]
        return list(
            BasicSweeper.split_overrides_to_chunks(all_batches, max_batch_size)
        )

    def sweep(self, arguments: List[str]) -> Any:
        """
        Runs every job described by arguments, one batch at a time.

        Before launching, the sweep config is saved as multirun.yaml in the
        directory named by config.hydra.sweep.dir (created if needed).

        :param arguments: override strings, expanded by split_arguments()
        :return: list with one entry per batch, each entry being the value
            returned by launcher.launch() for that batch
        """
        assert self.config is not None
        assert self.launcher is not None
        self.overrides = self.split_arguments(arguments, self.max_batch_size)
        # Start from the first batch of the new overrides; otherwise a second
        # sweep() on the same instance would skip batches using the stale
        # cursor left by the previous run.
        self.batch_index = 0
        returns: List[Sequence[JobReturn]] = []

        # Save sweep run config in top level sweep working directory
        sweep_dir = Path(self.config.hydra.sweep.dir)
        sweep_dir.mkdir(parents=True, exist_ok=True)
        OmegaConf.save(self.config, sweep_dir / "multirun.yaml")

        # Running count of jobs launched so far, passed to the launcher as
        # the index of the first job in each batch.
        initial_job_idx = 0
        while not self.is_done():
            batch = self.get_job_batch()
            results = self.launcher.launch(batch, initial_job_idx=initial_job_idx)
            initial_job_idx += len(batch)
            returns.append(results)
        return returns

    def get_job_batch(self) -> Sequence[Sequence[str]]:
        """
        Returns the next batch and advances the batch cursor.

        :return: A list of lists of strings, each inner list is the overrides for a single job
        that should be executed.
        """
        assert self.overrides is not None
        # Look the batch up before advancing so that a failed lookup does not
        # leave the cursor pointing past a batch that was never returned.
        batch = self.overrides[self.batch_index]
        self.batch_index += 1
        return batch

    def is_done(self) -> bool:
        """
        :return: True once every batch has been handed out by get_job_batch().
        """
        assert self.overrides is not None
        return self.batch_index >= len(self.overrides)