        base_configs = sorted(
            glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
        opt.base = base_configs + opt.base

    if opt.config:
        if isinstance(opt.config, str):
            opt.base = [opt.config]
        else:
            opt.base = [opt.base[-1]]

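    # load all base configs; they are merged in order, with CLI dotlist overrides taking highest precedence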
    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    if opt.ignore_base_data:
        for config in configs:
            if hasattr(config, "data"): del config["data"]
    config = OmegaConf.merge(*configs, cli)

    st.sidebar.text(ckpt)
    gs = st.sidebar.empty()
    gs.text(f"Global step: ?")
    st.sidebar.text("Options")
    #gpu = st.sidebar.checkbox("GPU", value=True)
    gpu = True
    #eval_mode = st.sidebar.checkbox("Eval Mode", value=True)
    eval_mode = True
    #show_config = st.sidebar.checkbox("Show Config", value=False)
    show_config = False
    if show_config:
        st.info("Checkpoint: {}".format(ckpt))
        st.json(OmegaConf.to_container(config))
Example #2
 def test_merge_list_with_correct_type(self, class_type: str) -> None:
     module: Any = import_module(class_type)
     cfg = OmegaConf.structured(module.UserList)
     user = module.User(name="John", age=21)
     res = OmegaConf.merge(cfg, {"list": [user]})
     assert res.list == [user]
Example #3
 def test_merge_dict_with_correct_type(self, class_type: str) -> None:
     module: Any = import_module(class_type)
     cfg = OmegaConf.structured(module.UserDict)
     user = module.User(name="John", age=21)
     res = OmegaConf.merge(cfg, {"dict": {"foo": user}})
     assert res.dict == {"foo": user}
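The structured configs above are imported from a test module at runtime; a plausible sketch of the definitions these tests assume (hypothetical, for illustration only):

from dataclasses import dataclass
from typing import Dict, List

from omegaconf import MISSING

@dataclass
class User:
    name: str = MISSING
    age: int = MISSING

@dataclass
class UserList:
    list: List[User] = MISSING

@dataclass
class UserDict:
    dict: Dict[str, User] = MISSING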
Example #4
    def _load_single_config(self, default: ResultDefault,
                            repo: IConfigRepository) -> ConfigResult:
        config_path = default.config_path

        assert config_path is not None
        ret = repo.load_config(config_path=config_path)
        assert ret is not None

        if not OmegaConf.is_config(ret.config):
            raise ValueError(
                f"Config {config_path} must be an OmegaConf config, got {type(ret.config).__name__}"
            )

        if not ret.is_schema_source:
            schema = None
            try:
                schema_source = repo.get_schema_source()
                cname = ConfigSource._normalize_file_name(filename=config_path)
                schema = schema_source.load_config(cname)
            except ConfigLoadError:
                # schema not found, ignore
                pass

            if schema is not None:
                try:
                    url = "https://hydra.cc/docs/next/upgrades/1.0_to_1.1/automatic_schema_matching"
                    if "defaults" in schema.config:
                        raise ConfigCompositionException(
                            dedent(f"""\
                            '{config_path}' is validated against ConfigStore schema with the same name.
                            This behavior is deprecated in Hydra 1.1 and will be removed in Hydra 1.2.
                            In addition, the automatically matched schema contains a defaults list.
                            This combination is no longer supported.
                            See {url} for migration instructions."""))
                    else:
                        deprecation_warning(
                            dedent(f"""\

                                '{config_path}' is validated against ConfigStore schema with the same name.
                                This behavior is deprecated in Hydra 1.1 and will be removed in Hydra 1.2.
                                See {url} for migration instructions."""),
                            stacklevel=11,
                        )

                    # if primary config has a hydra node, remove it during validation and add it back.
                    # This allows overriding Hydra's configuration without declaring its node
                    # in the schema of every primary config
                    hydra = None
                    hydra_config_group = (
                        default.config_path is not None
                        and default.config_path.startswith("hydra/"))
                    config = ret.config
                    if (default.primary and isinstance(config, DictConfig)
                            and "hydra" in config and not hydra_config_group):
                        hydra = config.pop("hydra")

                    merged = OmegaConf.merge(schema.config, config)
                    assert isinstance(merged, DictConfig)

                    if hydra is not None:
                        with open_dict(merged):
                            merged.hydra = hydra
                    ret.config = merged
                except OmegaConfBaseException as e:
                    raise ConfigCompositionException(
                        f"Error merging '{config_path}' with schema") from e

                assert isinstance(merged, DictConfig)

        res = self._embed_result_config(ret, default.package)
        if (not default.primary and config_path != "hydra/config"
                and isinstance(res.config, DictConfig) and OmegaConf.select(
                    res.config, "hydra.searchpath") is not None):
            raise ConfigCompositionException(
                f"In '{config_path}': Overriding hydra.searchpath is only supported from the primary config"
            )

        return res
Example #5
 def test_merge_into_Dict(self, class_type: str) -> None:
     module: Any = import_module(class_type)
     cfg = OmegaConf.structured(module.DictExamples)
     res = OmegaConf.merge(cfg, {"strings": {"x": "abc"}})
     assert res.strings == {"a": "foo", "b": "bar", "x": "abc"}
Example #6
def test_primitive_dicts() -> None:
    c1 = {"a": 10}
    c2 = {"b": 20}
    merged = OmegaConf.merge(c1, c2)
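    # the merge result is a DictConfig (an OmegaConf container), not a plain dict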
    assert merged == {"a": 10, "b": 20}
Example #7
def test_merge_error(base: Any, merge: Any, exception: Any) -> None:
    base = OmegaConf.create(base)
    merge = None if merge is None else OmegaConf.create(merge)
    with pytest.raises(exception):
        OmegaConf.merge(base, merge)
Example #8
def load_yaml_with_defaults(f):
    default_config = get_default_config_path()
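    # later arguments take precedence, so values from f override the defaults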
    return OmegaConf.merge(load_yaml(default_config), load_yaml(f))
Example #9
 def test_merge_structured_onto_dict(self, module: Any) -> None:
     c1 = OmegaConf.create({"name": 7})
     c2 = OmegaConf.merge(c1, module.User)
     assert c1 == {"name": 7}
     # type of name becomes str
     assert c2 == {"name": "7", "age": "???"}
Example #10
 def _test_merge(self, ref_type: Any, value: Any, assign: Any,
                 expectation: Any) -> None:
     cfg = OmegaConf.create(
         {"foo": DictConfig(ref_type=ref_type, content=value)})
     with expectation:
         OmegaConf.merge(cfg, {"foo": assign})
Example #11
def merge_with_parent(dc: FairseqDataclass, cfg: FairseqDataclass):
    merged_cfg = OmegaConf.merge(dc, cfg)
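    # OmegaConf.merge returns a detached config; restore the original parent so interpolations against it keep resolving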
    merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"]
    OmegaConf.set_struct(merged_cfg, True)
    return merged_cfg
Example #12
    def _load_single_config(
        self, default: ResultDefault, repo: IConfigRepository
    ) -> Tuple[ConfigResult, LoadTrace]:
        config_path = default.config_path

        assert config_path is not None
        ret = repo.load_config(config_path=config_path)
        assert ret is not None

        if not isinstance(ret.config, DictConfig):
            raise ValueError(
                f"Config {config_path} must be a Dictionary, got {type(ret).__name__}"
            )

        if not ret.is_schema_source:
            schema = None
            try:
                schema_source = repo.get_schema_source()
                cname = ConfigSource._normalize_file_name(filename=config_path)
                schema = schema_source.load_config(cname)
            except ConfigLoadError:
                # schema not found, ignore
                pass

            if schema is not None:
                try:
                    # TODO: deprecate schema matching in favor of extension via Defaults List
                    # if primary config has a hydra node, remove it during validation and add it back.
                    # This allows overriding Hydra's configuration without declaring its node
                    # in the schema of every primary config
                    hydra = None
                    hydra_config_group = (
                        default.config_path is not None
                        and default.config_path.startswith("hydra/")
                    )
                    if (
                        default.primary
                        and "hydra" in ret.config
                        and not hydra_config_group
                    ):
                        hydra = ret.config.pop("hydra")

                    merged = OmegaConf.merge(schema.config, ret.config)
                    assert isinstance(merged, DictConfig)

                    if hydra is not None:
                        with open_dict(merged):
                            merged.hydra = hydra
                    ret.config = merged
                except OmegaConfBaseException as e:
                    raise ConfigCompositionException(
                        f"Error merging '{config_path}' with schema"
                    ) from e

                assert isinstance(merged, DictConfig)

        trace = LoadTrace(
            config_path=default.config_path,
            package=default.package,
            parent=default.parent,
            is_self=default.is_self,
            search_path=ret.path,
            provider=ret.provider,
        )

        ret = self._embed_result_config(ret, default.package)

        return ret, trace
Example #13
def exp_manager(trainer: 'pytorch_lightning.Trainer',
                cfg: Optional[Union[DictConfig, Dict]] = None) -> Path:
    """
    exp_manager is a helper function used to manage folders for experiments. It follows the pytorch lightning paradigm
    of exp_dir/model_or_experiment_name/version. If the lightning trainer has a logger, exp_manager will get exp_dir,
    name, and version from the logger. Otherwise it will use the exp_dir and name arguments to create the logging
    directory. exp_manager also allows for explicit folder creation via explicit_log_dir.
    The version will be a datetime string or an integer. Note that exp_manager does not handle versioning on slurm
    multi-node runs. Datetime versioning can be disabled by setting use_datetime_version to False.
    It optionally creates TensorBoardLogger, WandBLogger, ModelCheckpoint objects from pytorch lightning. It copies
    sys.argv, and git information if available to the logging directory. It creates a log file for each process to log
    their output into.
    exp_manager additionally has a resume feature which can be used to continue training from the constructed log_dir.

    Args:
        trainer (pytorch_lightning.Trainer): The lightning trainer.
        cfg (DictConfig, dict): Can have the following keys:
            - explicit_log_dir (str, Path): Can be used to override exp_dir/name/version folder creation. Defaults to
                None, which will use exp_dir, name, and version to construct the logging directory.
            - exp_dir (str, Path): The base directory to create the logging directory. Defaults to None, which logs to
                ./nemo_experiments.
            - name (str): The name of the experiment. Defaults to None which turns into "default" via name = name or
                "default".
            - version (str): The version of the experiment. Defaults to None which uses either a datetime string or
                lightning's TensorboardLogger system of using version_{int}.
            - use_datetime_version (bool): Whether to use a datetime string for version. Defaults to True.
            - resume_if_exists (bool): Whether this experiment is resuming from a previous run. If True, it sets
                trainer.resume_from_checkpoint so that the trainer should auto-resume. exp_manager will move files
                under log_dir to log_dir/run_{int}. Defaults to False.
            - resume_past_end (bool): exp_manager errors out if resume_if_exists is True and a checkpoint matching
                *end.ckpt (indicating a previous training run fully completed) is found. Setting resume_past_end to
                True disables this check and loads the *end.ckpt instead. Defaults to False.
            - resume_ignore_no_checkpoint (bool): exp_manager errors out if resume_if_exists is True and no checkpoint
                could be found. This behaviour can be disabled, in which case exp_manager will print a message and
                continue without restoring, by setting resume_ignore_no_checkpoint to True. Defaults to False.
            - create_tensorboard_logger (bool): Whether to create a tensorboard logger and attach it to the pytorch
                lightning trainer. Defaults to True.
            - summary_writer_kwargs (dict): A dictionary of kwargs that can be passed to lightning's TensorboardLogger
                class. Note that log_dir is passed by exp_manager and cannot exist in this dict. Defaults to None.
            - create_wandb_logger (bool): Whether to create a Weights and Biases logger and attach it to the pytorch
                lightning trainer. Defaults to False.
            - wandb_logger_kwargs (dict): A dictionary of kwargs that can be passed to lightning's WandBLogger
                class. Note that name and project are required parameters if create_wandb_logger is True.
                Defaults to None.
            - create_checkpoint_callback (bool): Whether to create a ModelCheckpoint callback and attach it to the
                pytorch lightning trainer. The ModelCheckpoint saves the top 3 models with the best "val_loss", the most
                recent checkpoint under *last.ckpt, and the final checkpoint after training completes under *end.ckpt.
                Defaults to True.
            - files_to_copy (list): A list of files to copy to the experiment logging directory. Defaults to None which
                copies no files.

    Returns:
        log_dir (Path): The final logging directory where logging files are saved. Usually the concatenation of
            exp_dir, name, and version.
    """
    # Add rank information to logger
    # Note: trainer.global_rank and trainer.is_global_zero are not set until trainer.fit, so have to hack around it
    global_rank = trainer.node_rank * trainer.num_gpus + int(
        os.environ.get("LOCAL_RANK", 0))
    logging.rank = global_rank

    if cfg is None:
        logging.error(
            "exp_manager did not receive a cfg argument. It will be disabled.")
        return
    if trainer.fast_dev_run:
        logging.info(
            "Trainer was called with fast_dev_run. exp_manager will return without any functionality."
        )
        return

    # Ensure passed cfg is compliant with ExpManagerConfig
    schema = OmegaConf.structured(ExpManagerConfig)
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    elif not isinstance(cfg, DictConfig):
        raise ValueError(
            f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig"
        )
    cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True))
    cfg = OmegaConf.merge(schema, cfg)

    error_checks(
        trainer, cfg
    )  # Ensures that trainer options are compliant with NeMo and exp_manager arguments

    log_dir, exp_dir, name, version = get_log_dir(
        trainer=trainer,
        exp_dir=cfg.exp_dir,
        name=cfg.name,
        version=cfg.version,
        explicit_log_dir=cfg.explicit_log_dir,
        use_datetime_version=cfg.use_datetime_version,
    )

    if cfg.resume_if_exists:
        check_resume(trainer, log_dir, cfg.resume_past_end,
                     cfg.resume_ignore_no_checkpoint)

    checkpoint_name = name
    # If name returned from get_log_dir is "", use cfg.name for checkpointing
    if checkpoint_name is None or checkpoint_name == '':
        checkpoint_name = cfg.name or "default"
    cfg.name = name  # Used for configure_loggers so that the log_dir is properly set even if name is ""
    cfg.version = version

    # update app_state with log_dir, exp_dir, etc
    app_state = AppState()
    app_state.log_dir = log_dir
    app_state.exp_dir = exp_dir
    app_state.name = name
    app_state.version = version
    app_state.checkpoint_name = checkpoint_name
    app_state.create_checkpoint_callback = cfg.create_checkpoint_callback
    app_state.checkpoint_callback_params = cfg.checkpoint_callback_params

    # Create the logging directory if it does not exist
    os.makedirs(
        log_dir, exist_ok=True
    )  # Cannot limit creation to global zero as all ranks write to own log file
    logging.info(f'Experiments will be logged at {log_dir}')
    trainer._default_root_dir = log_dir

    # Handle Loggers by creating file and handle DEBUG statements
    log_file = log_dir / f'nemo_log_globalrank-{global_rank}_localrank-{int(os.environ.get("LOCAL_RANK", 0))}.txt'
    logging.add_file_handler(log_file)

    # For some reason, LearningRateLogger requires trainer to have a logger. Safer to create logger on all ranks
    # not just global rank 0.
    if cfg.create_tensorboard_logger or cfg.create_wandb_logger:
        configure_loggers(
            trainer,
            exp_dir,
            cfg.name,
            cfg.version,
            cfg.create_tensorboard_logger,
            cfg.summary_writer_kwargs,
            cfg.create_wandb_logger,
            cfg.wandb_logger_kwargs,
        )

    if cfg.create_checkpoint_callback:
        configure_checkpointing(trainer, log_dir, checkpoint_name,
                                cfg.checkpoint_callback_params)

    if is_global_rank_zero():
        # Move files_to_copy to folder and add git information if present
        if cfg.files_to_copy:
            for _file in cfg.files_to_copy:
                copy(Path(_file), log_dir)

        # Create files for cmd args and git info
        with open(log_dir / 'cmd-args.log', 'w') as _file:
            _file.write(" ".join(sys.argv))

        # Try to get git hash
        git_repo, git_hash = get_git_hash()
        if git_repo:
            with open(log_dir / 'git-info.log', 'w') as _file:
                _file.write(f'commit hash: {git_hash}')
                _file.write(get_git_diff())

        # Add err_file logging to global_rank zero
        logging.add_err_file_handler(log_dir / 'nemo_error_log.txt')

        # Add lightning file logging to global_rank zero
        add_filehandlers_to_pl_logger(log_dir / 'lightning_logs.txt',
                                      log_dir / 'nemo_error_log.txt')

    return log_dir
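A minimal usage sketch (hypothetical names and paths; assumes a pytorch_lightning Trainer has already been constructed):

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=1)
log_dir = exp_manager(
    trainer,
    {"name": "my_model", "exp_dir": "./nemo_experiments", "create_wandb_logger": False},
)
print(f"Logging to: {log_dir}")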
Example #14
 ),
 pytest.param(
     Expected(
         create=lambda: create_readonly({"foo1": "bar"}),
         op=lambda cfg: cfg.merge_with({"foo2": "bar"}),
         exception_type=ReadonlyConfigError,
         key="foo2",
         msg="Cannot change read-only config container",
     ),
     id="dict,readonly:merge_with",
 ),
 pytest.param(
     Expected(
         create=lambda: OmegaConf.structured(ConcretePlugin),
         op=lambda cfg: OmegaConf.merge(cfg, {"params": {
             "foo": "bar"
         }}),
         exception_type=ValidationError,
         msg="Value 'bar' could not be converted to Integer",
         key="foo",
         full_key="params.foo",
         object_type=ConcretePlugin.FoobarParams,
         child_node=lambda cfg: cfg.params.foo,
         parent_node=lambda cfg: cfg.params,
     ),
     id="structured:merge,invalid_field_type",
 ),
 pytest.param(
     Expected(
         create=lambda: OmegaConf.structured(ConcretePlugin),
         op=lambda cfg: OmegaConf.merge(cfg, {"params": {
Example #15
    def load_configuration(
        self,
        config_name: Optional[str],
        overrides: List[str],
        strict: Optional[bool] = None,
    ) -> DictConfig:
        assert config_name is None or isinstance(config_name, str)
        assert strict is None or isinstance(strict, bool)
        assert isinstance(overrides, list)
        if strict is None:
            strict = self.default_strict

        assert overrides is None or isinstance(overrides, list)
        overrides = copy.deepcopy(overrides) or []

        if config_name is not None and not self.exists_in_search_path(
                config_name):
            # TODO: handle schema as a special case
            descs = [
                f"\t{src.path} (from {src.provider})"
                for src in self.repository.get_sources()
            ]
            lines = "\n".join(descs)
            raise MissingConfigException(
                missing_cfg_file=config_name,
                message=
                f"Cannot find primary config file: {config_name}\nSearch path:\n{lines}",
            )

        # Load hydra config
        hydra_cfg, _load_trace = self._create_cfg(cfg_filename="hydra_config")

        # Load job config
        job_cfg, job_cfg_load_trace = self._create_cfg(
            cfg_filename=config_name, record_load=False)

        job_defaults = ConfigLoaderImpl._get_defaults(job_cfg)
        defaults = ConfigLoaderImpl._get_defaults(hydra_cfg)

        job_cfg_type = OmegaConf.get_type(job_cfg)
        if job_cfg_type is not None and not issubclass(job_cfg_type, dict):
            hydra_cfg._promote(job_cfg_type)
            # this is breaking encapsulation a bit. can potentially be implemented in OmegaConf
            hydra_cfg._metadata.ref_type = job_cfg._metadata.ref_type

        # if defaults are re-introduced by the promotion, remove it.
        if "defaults" in hydra_cfg:
            with open_dict(hydra_cfg):
                del hydra_cfg["defaults"]

        if config_name is not None:
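            # "__SELF__" marks where the primary config itself is merged relative to its defaults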
            defaults.append("__SELF__")
        split_at = len(defaults)

        ConfigLoaderImpl._merge_default_lists(defaults, job_defaults)
        consumed = self._apply_defaults_overrides(overrides, defaults)

        consumed_free_job_defaults = self._apply_free_defaults(
            defaults, overrides)

        ConfigLoaderImpl._validate_defaults(defaults)

        # Load the defaults and merge them into cfg
        cfg = self._merge_defaults(hydra_cfg, job_cfg, job_cfg_load_trace,
                                   defaults, split_at)
        OmegaConf.set_struct(cfg.hydra, True)
        OmegaConf.set_struct(cfg, strict)

        # Merge all command line overrides after enabling strict flag
        all_consumed = consumed + consumed_free_job_defaults
        remaining_overrides = [x for x in overrides if x not in all_consumed]
        try:
            merged = OmegaConf.merge(
                cfg, OmegaConf.from_dotlist(remaining_overrides))
            assert isinstance(merged, DictConfig)
            cfg = merged
        except OmegaConfBaseException as ex:
            raise HydraException("Error merging overrides") from ex

        remaining = consumed + consumed_free_job_defaults + remaining_overrides

        def is_hydra(x: str) -> bool:
            return x.startswith("hydra.") or x.startswith("hydra/")

        cfg.hydra.overrides.task = [x for x in remaining if not is_hydra(x)]
        cfg.hydra.overrides.hydra = [x for x in remaining if is_hydra(x)]

        with open_dict(cfg.hydra.job):
            if "name" not in cfg.hydra.job:
                cfg.hydra.job.name = JobRuntime().get("name")
            cfg.hydra.job.override_dirname = get_overrides_dirname(
                input_list=cfg.hydra.overrides.task,
                kv_sep=cfg.hydra.job.config.override_dirname.kv_sep,
                item_sep=cfg.hydra.job.config.override_dirname.item_sep,
                exclude_keys=cfg.hydra.job.config.override_dirname.
                exclude_keys,
            )
            cfg.hydra.job.config_name = config_name

            for key in cfg.hydra.job.env_copy:
                cfg.hydra.job.env_set[key] = os.environ[key]

        return cfg
Example #16
 def test_merge_missing_object_onto_typed_dictconfig(self,
                                                     module: Any) -> None:
     c1 = OmegaConf.structured(module.DictOfObjects)
     c2 = OmegaConf.merge(c1, {"users": {"bob": "???"}})
     assert isinstance(c2, DictConfig)
     assert OmegaConf.is_missing(c2.users, "bob")
Example #17
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)

        self._parser = parsers.make_parser(
            labels=cfg.labels,
            name='en',
            unk_id=-1,
            blank_id=-1,
            do_normalize=True,
            abbreviation_version="fastpitch",
            make_table=False,
        )

        super().__init__(cfg=cfg, trainer=trainer)

        schema = OmegaConf.structured(FastPitchHifiGanE2EConfig)
        # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        elif not isinstance(cfg, DictConfig):
            raise ValueError(
                f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig"
            )
        # Ensure passed cfg is compliant with schema
        OmegaConf.merge(cfg, schema)

        self.preprocessor = instantiate(cfg.preprocessor)
        self.melspec_fn = instantiate(cfg.preprocessor,
                                      highfreq=None,
                                      use_grads=True)

        self.encoder = instantiate(cfg.input_fft)
        self.duration_predictor = instantiate(cfg.duration_predictor)
        self.pitch_predictor = instantiate(cfg.pitch_predictor)

        self.generator = instantiate(cfg.generator)
        self.multiperioddisc = MultiPeriodDiscriminator()
        self.multiscaledisc = MultiScaleDiscriminator()
        self.mel_val_loss = L1MelLoss()
        self.feat_matching_loss = FeatureMatchingLoss()
        self.disc_loss = DiscriminatorLoss()
        self.gen_loss = GeneratorLoss()

        self.max_token_duration = cfg.max_token_duration

        self.pitch_emb = torch.nn.Conv1d(
            1,
            cfg.symbols_embedding_dim,
            kernel_size=cfg.pitch_embedding_kernel_size,
            padding=int((cfg.pitch_embedding_kernel_size - 1) / 2),
        )

        # Store values precomputed from training data for convenience
        self.register_buffer('pitch_mean', torch.zeros(1))
        self.register_buffer('pitch_std', torch.zeros(1))

        self.loss = BaseFastPitchLoss()

        self.mel_loss_coeff = cfg.mel_loss_coeff

        self.log_train_images = False
        self.logged_real_samples = False
        self._tb_logger = None
        self.hann_window = None
        self.splice_length = cfg.splice_length
        self.sample_rate = cfg.sample_rate
        self.hop_size = cfg.hop_size
Example #18
 def test_merge_into_missing_sc(self, module: Any) -> None:
     c1 = OmegaConf.structured(module.PluginHolder)
     c2 = OmegaConf.merge(c1, {"plugin": "???"})
     assert c2.plugin == module.Plugin()
Example #19
def test_3way_dict_merge() -> None:
    c1 = OmegaConf.create("{a: 1, b: 2}")
    c2 = OmegaConf.create("{b: 3}")
    c3 = OmegaConf.create("{a: 2, c: 3}")
    c4 = OmegaConf.merge(c1, c2, c3)
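    # later configs win key-by-key: "a" comes from c3, "b" from c2, "c" from c3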
    assert {"a": 2, "b": 3, "c": 3} == c4
Example #20
 def test_plugin_merge(self, module: Any) -> None:
     plugin = OmegaConf.structured(module.Plugin)
     concrete = OmegaConf.structured(module.ConcretePlugin)
     ret = OmegaConf.merge(plugin, concrete)
     assert ret == concrete
     assert OmegaConf.get_type(ret) == module.ConcretePlugin
Example #21
def test_with_readonly(c1: Any, c2: Any) -> None:
    cfg = OmegaConf.create(c1)
    OmegaConf.set_readonly(cfg, True)
    cfg2 = OmegaConf.merge(cfg, c2)
    assert OmegaConf.is_readonly(cfg2)
Example #22
 def test_merge_of_non_subclass_1(self, module: Any) -> None:
     cfg1 = OmegaConf.create({"plugin": module.Plugin})
     cfg2 = OmegaConf.create({"plugin": module.FaultyPlugin})
     with raises(ValidationError):
         OmegaConf.merge(cfg1, cfg2)
Example #23
 def test_merged_with_nons_subclass(self, class_type: str) -> None:
     module: Any = import_module(class_type)
     c1 = OmegaConf.structured(module.Plugin)
     c2 = OmegaConf.structured(module.FaultyPlugin)
     with pytest.raises(ValidationError):
         OmegaConf.merge(c1, c2)
Example #24
 def test_merge_error_new_attribute(self, module: Any) -> None:
     cfg = OmegaConf.structured(module.ConcretePlugin)
     cfg2 = OmegaConf.create({"params": {"bar": 10}})
     # raise if an invalid key is merged into a struct
     with raises(ConfigKeyError):
         OmegaConf.merge(cfg, cfg2)
Example #25
 def test_merge_user_list_with_wrong_key(self, class_type: str) -> None:
     module: Any = import_module(class_type)
     cfg = OmegaConf.structured(module.UserList)
     with pytest.raises(ConfigKeyError):
         OmegaConf.merge(cfg, {"list": [{"foo": "var"}]})
Example #26
        def test_merge_error_override_bad_type(self, module: Any) -> None:
            cfg = OmegaConf.structured(module.ConcretePlugin)

            # raise if an invalid key is merged into a struct
            with raises(ValidationError):
                OmegaConf.merge(cfg, {"params": {"foo": "zonk"}})
Example #27
 def test_merge_dict_with_wrong_type(self, class_type: str) -> None:
     module: Any = import_module(class_type)
     cfg = OmegaConf.structured(module.UserDict)
     with pytest.raises(ValidationError):
         OmegaConf.merge(cfg, {"dict": {"foo": "var"}})
Example #28
    def _load_config_impl(
        self,
        input_file: str,
        record_load: bool = True
    ) -> Tuple[Optional[DictConfig], Optional[LoadTrace]]:
        """
        :param input_file:
        :param record_load:
        :return: the loaded config or None if it was not found
        """
        def record_loading(
            name: str,
            path: Optional[str],
            provider: Optional[str],
            schema_provider: Optional[str],
        ) -> Optional[LoadTrace]:
            trace = LoadTrace(
                filename=name,
                path=path,
                provider=provider,
                schema_provider=schema_provider,
            )

            if record_load:
                self.all_config_checked.append(trace)

            return trace

        ret = self.repository.load_config(config_path=input_file)

        if ret is not None:
            if not isinstance(ret.config, DictConfig):
                raise ValueError(
                    f"Config {input_file} must be a Dictionary, got {type(ret).__name__}"
                )
            if not ret.is_schema_source:
                try:
                    schema = ConfigStore.instance().load(
                        config_path=ConfigSource._normalize_file_name(
                            filename=input_file))

                    merged = OmegaConf.merge(schema.node, ret.config)
                    assert isinstance(merged, DictConfig)
                    return (
                        merged,
                        record_loading(
                            name=input_file,
                            path=ret.path,
                            provider=ret.provider,
                            schema_provider=schema.provider,
                        ),
                    )

                except ConfigLoadError:
                    # schema not found, ignore
                    pass

            return (
                ret.config,
                record_loading(
                    name=input_file,
                    path=ret.path,
                    provider=ret.provider,
                    schema_provider=None,
                ),
            )
        else:
            return (
                None,
                record_loading(name=input_file,
                               path=None,
                               provider=None,
                               schema_provider=None),
            )
Example #29
    def initialize(self, config_path: str, force: bool = False) -> OmegaConf:
        """
        Args:
            config_path: a file or dir
            force: ignore if initialized
        Returns:
            user_config: only return the user config
        """

        if self.initialized and not force:
            raise ValueError("FlyConfig is already initialized!")

        config_path = check_config_path(config_path)

        init_omegaconf()

        system_config = load_system_config()
        user_config = load_user_config(config_path)

        config = OmegaConf.merge(system_config, user_config)

        # get current working dir
        config.flyconfig.runtime.cwd = os.getcwd()

        # change working dir
        if not self.disable_chdir:
            working_dir_path = config.flyconfig.run.dir
            os.makedirs(working_dir_path, exist_ok=True)
            os.chdir(working_dir_path)

        # configure logging
        if int(os.environ.get("LOCAL_RANK",
                              0)) == 0 and not self.disable_logging:
            logging.config.dictConfig(
                OmegaConf.to_container(config.flyconfig.logging))
            logger.info("FlyConfig Initialized")
            if not self.disable_chdir:
                logger.info(
                    f"Working directory is changed to {working_dir_path}")

        # clean defaults
        del config["defaults"]

        # overrides
        overrides = get_overrides_from_argv(sys.argv[1:])
        overrides_config = OmegaConf.from_dotlist(overrides)
        config = OmegaConf.merge(config, overrides_config)

        # get system config
        self.system_config = OmegaConf.create(
            {"flyconfig": OmegaConf.to_container(config.flyconfig)})

        # get user config
        self.user_config = copy.deepcopy(config)
        del self.user_config["flyconfig"]

        # save config
        if int(os.environ.get("LOCAL_RANK",
                              0)) == 0 and not self.disable_logging:
            os.makedirs(self.system_config.flyconfig.output_subdir,
                        exist_ok=True)

            # save the entire config directory
            cwd = self.system_config.flyconfig.runtime.cwd
            dirpath = os.path.join(cwd, os.path.dirname(config_path))
            shutil.copytree(
                dirpath,
                os.path.join(self.system_config.flyconfig.output_subdir,
                             "config"))

            # save system config
            _save_config(filepath=os.path.join(
                self.system_config.flyconfig.output_subdir, "flyconfig.yml"),
                         config=self.system_config)

            # save user config
            _save_config(filepath=os.path.join(
                self.system_config.flyconfig.output_subdir, "config.yml"),
                         config=self.user_config)

            logger.info("\n\nConfiguration:\n" + self.user_config.pretty())

        self.initialized = True

        return self.user_config
Example #30
from hydra.utils import get_class, instantiate
from omegaconf import OmegaConf
from torch import Tensor

cfg = {}

full_class = "gen.configen_tests.utils.data.dataset.TensorDatasetConf"
schema = OmegaConf.structured(get_class(full_class))
cfg = OmegaConf.merge(schema, cfg)
obj = instantiate(cfg, tensors=(Tensor([1]),))  # note the trailing comma: tensors is a tuple

print(obj)