示例#1
0
    def _apply_overrides_to_config(overrides: List[str],
                                   cfg: DictConfig) -> None:
        """Apply a list of override strings to ``cfg`` in place.

        Every entry is parsed with ``ConfigLoaderImpl._parse_config_override``
        and then handled as one of three cases, depending on what the parsed
        override reports:

        * delete -- the key must exist, and when the override carries a value
          it must equal the current one; the key is then removed.
        * add -- the key must not resolve to a value yet; it is then created.
        * otherwise -- the value at the key is updated.

        :param overrides: override strings to apply, in order.
        :param cfg: config object that is modified in place.
        :raises HydraException: when a delete/add precondition fails, when a
            plain override does not match anything in the config, or when
            OmegaConf reports any other error while applying an override.
        """
        yaml_loader = _utils.get_yaml_loader()

        for line in overrides:
            override = ConfigLoaderImpl._parse_config_override(line)
            try:
                # The textual value is interpreted as YAML; an override with
                # no value at all yields None.
                raw_value = override.value
                if raw_value is None:
                    value = None
                else:
                    value = yaml.load(raw_value, Loader=yaml_loader)

                if override.is_delete():
                    existing = OmegaConf.select(cfg,
                                                override.key,
                                                throw_on_missing=False)
                    if existing is None:
                        raise HydraException(
                            f"Could not delete from config. '{override.key}' does not exist."
                        )
                    if value is not None and value != existing:
                        raise HydraException(
                            f"Could not delete from config."
                            f" The value of '{override.key}' is {existing} and not {override.value}."
                        )

                    # Split a dotted key into its parent path and final
                    # component; a key without any dot lives directly on cfg.
                    parent_path, dot, leaf = override.key.rpartition(".")
                    with open_dict(cfg):
                        if dot:
                            parent = OmegaConf.select(cfg, parent_path)
                            del parent[leaf]
                        else:
                            del cfg[override.key]

                elif override.is_add():
                    already_there = OmegaConf.select(cfg,
                                                     override.key,
                                                     throw_on_missing=False)
                    if already_there is not None:
                        raise HydraException(
                            f"Could not append to config. An item is already at '{override.key}'."
                        )
                    with open_dict(cfg):
                        OmegaConf.update(cfg, override.key, value)

                else:
                    try:
                        OmegaConf.update(cfg, override.key, value)
                    except (ConfigAttributeError, ConfigKeyError) as ex:
                        raise HydraException(
                            f"Could not override '{override.key}'. No match in config."
                            f"\nTo append to your config use +{line}") from ex
            except OmegaConfBaseException as ex:
                raise HydraException(f"Error merging override {line}") from ex
示例#2
0
class BasicSweeper(Sweeper):
    """
    Sweeper that expands the given overrides into the Cartesian product of
    their choices and passes the resulting jobs to a launcher, either all at
    once or in batches of at most ``max_batch_size`` jobs.
    """

    def __init__(self, max_batch_size: Optional[int]) -> None:
        """
        :param max_batch_size: maximum number of jobs per launched batch;
            None puts every job into a single batch.
        """
        super().__init__()
        self.overrides: Optional[Sequence[Sequence[Sequence[str]]]] = None
        self.batch_index = 0
        self.max_batch_size = max_batch_size
        # Populated by setup(). Initialised here so that calling sweep()
        # before setup() trips the explicit assertions there instead of
        # failing with an AttributeError.
        self.config: Optional[DictConfig] = None
        self.launcher: Optional[Any] = None

    def setup(
        self,
        config: DictConfig,
        config_loader: ConfigLoader,
        task_function: TaskFunction,
    ) -> None:
        """
        Store the config and instantiate the launcher used by sweep().

        :param config: config object kept on the sweeper and passed to the launcher.
        :param config_loader: passed through to the launcher instantiation.
        :param task_function: passed through to the launcher instantiation.
        """
        # Imported locally rather than at module level.
        # NOTE(review): presumably to avoid an import cycle -- confirm.
        from hydra.core.plugins import Plugins

        self.config = config

        self.launcher = Plugins.instance().instantiate_launcher(
            config=config, config_loader=config_loader, task_function=task_function
        )

    # Class-level YAML loader; it is not referenced anywhere inside this class.
    loader = get_yaml_loader()

    @staticmethod
    def split_overrides_to_chunks(
        lst: List[List[str]], n: Optional[int]
    ) -> Iterable[List[List[str]]]:
        """
        Yield consecutive slices of ``lst`` holding at most ``n`` items each.

        :param lst: list of jobs, each job being a list of override strings.
        :param n: chunk size; None or -1 means a single chunk with everything.
        :return: an iterable over the chunks; empty when ``lst`` is empty.
        """
        if n is None or n == -1:
            # max(..., 1) keeps the range() step valid for an empty list, in
            # which case nothing is yielded.
            n = max(len(lst), 1)
        assert n > 0
        for i in range(0, len(lst), n):
            yield lst[i : i + n]

    @staticmethod
    def split_arguments(
        arguments: List[str], max_batch_size: Optional[int]
    ) -> List[List[List[str]]]:
        """
        Expand override strings into batches of concrete jobs.

        Sweep overrides contribute one ``key=value`` string per choice, all
        other overrides contribute a single ``key=value`` string. The jobs are
        the Cartesian product of these per-override lists.

        :param arguments: override strings as accepted by OverridesParser.
        :param max_batch_size: maximum number of jobs per batch; None returns
            one batch holding every job. Must be positive when given.
        :return: list of batches; each batch is a list of jobs and each job is
            a list of ``key=value`` override strings.
        """
        parser = OverridesParser()
        parsed = parser.parse_overrides(arguments)

        lists: List[List[str]] = []
        for override in parsed:
            key = override.get_key_element()
            if override.is_sweep_override():
                sweep_choices = override.choices_as_strings()
                assert isinstance(sweep_choices, list)
                lists.append([f"{key}={val}" for val in sweep_choices])
            else:
                value = override.get_value_element()
                lists.append([f"{key}={value}"])
        all_batches = [list(x) for x in itertools.product(*lists)]
        assert max_batch_size is None or max_batch_size > 0
        if max_batch_size is None:
            return [all_batches]
        return list(
            BasicSweeper.split_overrides_to_chunks(all_batches, max_batch_size)
        )

    def sweep(self, arguments: List[str]) -> Any:
        """
        Launch every job described by ``arguments``, one batch at a time.

        :param arguments: override strings, expanded via split_arguments().
        :return: a list with one launcher result per launched batch.
        """
        assert self.config is not None
        assert self.launcher is not None
        self.overrides = self.split_arguments(arguments, self.max_batch_size)
        returns: List[Sequence[JobReturn]] = []

        # Save sweep run config in top level sweep working directory
        sweep_dir = Path(self.config.hydra.sweep.dir)
        sweep_dir.mkdir(parents=True, exist_ok=True)
        OmegaConf.save(self.config, sweep_dir / "multirun.yaml")

        # Running count of jobs launched so far, passed to the launcher as the
        # starting index of each batch.
        initial_job_idx = 0
        while not self.is_done():
            batch = self.get_job_batch()
            results = self.launcher.launch(batch, initial_job_idx=initial_job_idx)
            initial_job_idx += len(batch)
            returns.append(results)
        return returns

    def get_job_batch(self) -> Sequence[Sequence[str]]:
        """
        Advance to the next batch and return it.

        :return: A list of lists of strings, each inner list is the overrides for a single job
        that should be executed.
        """
        assert self.overrides is not None
        self.batch_index += 1
        return self.overrides[self.batch_index - 1]

    def is_done(self) -> bool:
        """
        :return: True once every batch has been handed out by get_job_batch().
        """
        assert self.overrides is not None
        return self.batch_index >= len(self.overrides)