# NOTE(review): fragment of a larger script — `logdir`, `opt`, `unknown`, `ckpt`
# and `st` are defined before this point (not visible here), and the original
# indentation of these statements was lost, so their nesting cannot be confirmed.

# Prepend the project configs found under `<logdir>/configs/` (sorted by file
# name) to the base configs given on the command line.
base_configs = sorted(
    glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
opt.base = base_configs + opt.base

# An explicit config narrows the list to a single file: the given path when it
# is a string, otherwise only the last base config.
if opt.config:
    if type(opt.config) == str:
        opt.base = [opt.config]
    else:
        opt.base = [opt.base[-1]]

# Load every base config and turn the leftover CLI arguments into a dotlist
# config; OmegaConf.merge gives later sources precedence, so `cli` wins.
configs = [OmegaConf.load(cfg) for cfg in opt.base]
cli = OmegaConf.from_dotlist(unknown)
if opt.ignore_base_data:
    # Drop the `data` section from each loaded config before merging.
    for config in configs:
        if hasattr(config, "data"):
            del config["data"]
config = OmegaConf.merge(*configs, cli)

# Sidebar: checkpoint path plus a placeholder for the global step (`gs`),
# which is presumably updated later in the script — not visible here.
st.sidebar.text(ckpt)
gs = st.sidebar.empty()
gs.text(f"Global step: ?")
st.sidebar.text("Options")
# The interactive toggles are disabled; each option is hard-coded instead.
#gpu = st.sidebar.checkbox("GPU", value=True)
gpu = True
#eval_mode = st.sidebar.checkbox("Eval Mode", value=True)
eval_mode = True
#show_config = st.sidebar.checkbox("Show Config", value=False)
show_config = False
if show_config:
    st.info("Checkpoint: {}".format(ckpt))
    st.json(OmegaConf.to_container(config))
def test_merge_list_with_correct_type(self, class_type: str) -> None:
    """Merging a list holding a ``User`` into ``UserList`` stores it unchanged."""
    mod: Any = import_module(class_type)
    schema = OmegaConf.structured(mod.UserList)
    john = mod.User(name="John", age=21)
    merged = OmegaConf.merge(schema, {"list": [john]})
    assert merged.list == [john]
def test_merge_dict_with_correct_type(self, class_type: str) -> None:
    """Merging a ``User`` value into ``UserDict`` stores it under its key unchanged."""
    mod: Any = import_module(class_type)
    schema = OmegaConf.structured(mod.UserDict)
    john = mod.User(name="John", age=21)
    merged = OmegaConf.merge(schema, {"dict": {"foo": john}})
    assert merged.dict == {"foo": john}
def _load_single_config(self, default: ResultDefault, repo: IConfigRepository) -> ConfigResult:
    """Load the config at ``default.config_path`` and validate it against a same-named schema.

    If the repository's schema source has a config with the same (normalized) name, the
    loaded config is merged onto it. This automatic schema matching is deprecated: it emits
    a deprecation warning, or fails outright when the schema carries a defaults list.
    The result is passed through ``self._embed_result_config(ret, default.package)``.

    Raises:
        ValueError: if the loaded object is not an OmegaConf config.
        ConfigCompositionException: if the matched schema contains a defaults list, if
            merging with the schema fails, or if a non-primary config (other than
            ``hydra/config``) sets ``hydra.searchpath``.
    """
    config_path = default.config_path

    assert config_path is not None
    ret = repo.load_config(config_path=config_path)
    assert ret is not None

    if not OmegaConf.is_config(ret.config):
        raise ValueError(
            f"Config {config_path} must be an OmegaConf config, got {type(ret.config).__name__}"
        )

    if not ret.is_schema_source:
        schema = None
        try:
            schema_source = repo.get_schema_source()
            cname = ConfigSource._normalize_file_name(filename=config_path)
            schema = schema_source.load_config(cname)
        except ConfigLoadError:
            # schema not found, ignore
            pass

        if schema is not None:
            try:
                url = "https://hydra.cc/docs/next/upgrades/1.0_to_1.1/automatic_schema_matching"
                if "defaults" in schema.config:
                    # A matched schema with its own defaults list is rejected outright.
                    raise ConfigCompositionException(
                        dedent(
                            f"""\
                        '{config_path}' is validated against ConfigStore schema with the same name.
                        This behavior is deprecated in Hydra 1.1 and will be removed in Hydra 1.2.
                        In addition, the automatically matched schema contains a defaults list.
                        This combination is no longer supported.
                        See {url} for migration instructions."""
                        )
                    )
                else:
                    # NOTE(review): stacklevel=11 presumably points the warning past Hydra's
                    # internal frames at user code — confirm if the call stack changes.
                    deprecation_warning(
                        dedent(
                            f"""\
                        '{config_path}' is validated against ConfigStore schema with the same name.
                        This behavior is deprecated in Hydra 1.1 and will be removed in Hydra 1.2.
                        See {url} for migration instructions."""
                        ),
                        stacklevel=11,
                    )

                # if primary config has a hydra node, remove it during validation and add it back.
                # This allows overriding Hydra's configuration without declaring its node
                # in the schema of every primary config
                hydra = None
                # Configs inside the hydra/ group keep their hydra node during validation.
                hydra_config_group = (
                    default.config_path is not None
                    and default.config_path.startswith("hydra/")
                )
                config = ret.config
                if (
                    default.primary
                    and isinstance(config, DictConfig)
                    and "hydra" in config
                    and not hydra_config_group
                ):
                    hydra = config.pop("hydra")

                merged = OmegaConf.merge(schema.config, config)
                assert isinstance(merged, DictConfig)

                if hydra is not None:
                    # open_dict lifts struct mode so the undeclared hydra key can be re-added.
                    with open_dict(merged):
                        merged.hydra = hydra
                ret.config = merged
            except OmegaConfBaseException as e:
                raise ConfigCompositionException(
                    f"Error merging '{config_path}' with schema"
                ) from e

            assert isinstance(merged, DictConfig)

    res = self._embed_result_config(ret, default.package)
    # Only the primary config (or hydra/config itself) may override hydra.searchpath.
    if (
        not default.primary
        and config_path != "hydra/config"
        and isinstance(res.config, DictConfig)
        and OmegaConf.select(res.config, "hydra.searchpath") is not None
    ):
        raise ConfigCompositionException(
            f"In '{config_path}': Overriding hydra.searchpath is only supported from the primary config"
        )

    return res
def test_merge_into_Dict(self, class_type: str) -> None:
    """A new key merged into a typed ``Dict`` field is added next to the existing entries."""
    mod: Any = import_module(class_type)
    examples = OmegaConf.structured(mod.DictExamples)
    merged = OmegaConf.merge(examples, {"strings": {"x": "abc"}})
    assert merged.strings == {"a": "foo", "b": "bar", "x": "abc"}
def test_primitive_dicts() -> None:
    """Plain dicts with disjoint keys merge into their union."""
    merged = OmegaConf.merge({"a": 10}, {"b": 20})
    assert merged == {"a": 10, "b": 20}
def test_merge_error(base: Any, merge: Any, exception: Any) -> None:
    """Merging ``merge`` (or None) onto ``base`` must raise ``exception``."""
    base = OmegaConf.create(base)
    if merge is not None:
        merge = OmegaConf.create(merge)
    with pytest.raises(exception):
        OmegaConf.merge(base, merge)
def load_yaml_with_defaults(f):
    """Load YAML file ``f`` layered on top of the default config.

    Values from ``f`` take precedence over those in the default config file.
    """
    defaults = load_yaml(get_default_config_path())
    overrides = load_yaml(f)
    return OmegaConf.merge(defaults, overrides)
def test_merge_structured_onto_dict(self, module: Any) -> None:
    """Merging a structured class onto a plain config re-types the shared field."""
    plain = OmegaConf.create({"name": 7})
    merged = OmegaConf.merge(plain, module.User)
    # The input config itself is left untouched by the merge.
    assert plain == {"name": 7}
    # `name` takes the schema's str type; `age` stays missing ("???").
    assert merged == {"name": "7", "age": "???"}
def _test_merge(self, ref_type: Any, value: Any, assign: Any, expectation: Any) -> None:
    """Merge ``assign`` into a ``foo`` node of type ``ref_type`` inside ``expectation``."""
    node = DictConfig(ref_type=ref_type, content=value)
    cfg = OmegaConf.create({"foo": node})
    with expectation:
        OmegaConf.merge(cfg, {"foo": assign})
def merge_with_parent(dc: FairseqDataclass, cfg: FairseqDataclass):
    """Merge ``cfg`` onto ``dc``, give the result ``cfg``'s parent, and lock it in struct mode."""
    result = OmegaConf.merge(dc, cfg)
    # Re-parent by copying the private `_parent` slot straight from cfg.
    parent = cfg.__dict__["_parent"]
    result.__dict__["_parent"] = parent
    OmegaConf.set_struct(result, True)
    return result
def _load_single_config(
    self, default: ResultDefault, repo: IConfigRepository
) -> Tuple[ConfigResult, LoadTrace]:
    """Load the config referenced by ``default`` and validate it against a same-named schema, if any.

    If the repository's schema source holds a config with the same (normalized) name, the loaded
    config is merged onto it. For a primary config outside the ``hydra/`` group, a top-level
    ``hydra`` node is removed during that merge and re-attached afterwards, so primary configs can
    override Hydra's configuration without declaring it in their schema.

    Returns:
        The config embedded under ``default.package``, and a LoadTrace recording its config path,
        package, parent, is_self flag, search path and provider.

    Raises:
        ValueError: if the loaded config is not a DictConfig.
        ConfigCompositionException: if merging the config with its schema fails.
    """
    config_path = default.config_path

    assert config_path is not None
    ret = repo.load_config(config_path=config_path)
    assert ret is not None

    if not isinstance(ret.config, DictConfig):
        # Report the type of the loaded config itself, not of its ConfigResult wrapper.
        raise ValueError(
            f"Config {config_path} must be a Dictionary, got {type(ret.config).__name__}"
        )

    if not ret.is_schema_source:
        schema = None
        try:
            schema_source = repo.get_schema_source()
            cname = ConfigSource._normalize_file_name(filename=config_path)
            schema = schema_source.load_config(cname)
        except ConfigLoadError:
            # schema not found, ignore
            pass

        if schema is not None:
            try:
                # TODO: deprecate schema matching in favor of extension via Defaults List
                # if primary config has a hydra node, remove it during validation and add it back.
                # This allows overriding Hydra's configuration without declaring its node
                # in the schema of every primary config
                hydra = None
                hydra_config_group = (
                    default.config_path is not None
                    and default.config_path.startswith("hydra/")
                )
                if (
                    default.primary
                    and "hydra" in ret.config
                    and not hydra_config_group
                ):
                    hydra = ret.config.pop("hydra")

                merged = OmegaConf.merge(schema.config, ret.config)
                assert isinstance(merged, DictConfig)

                if hydra is not None:
                    # open_dict lifts struct mode so the undeclared hydra key can be re-added.
                    with open_dict(merged):
                        merged.hydra = hydra
                ret.config = merged
            except OmegaConfBaseException as e:
                raise ConfigCompositionException(
                    f"Error merging '{config_path}' with schema"
                ) from e

            assert isinstance(merged, DictConfig)

    trace = LoadTrace(
        config_path=default.config_path,
        package=default.package,
        parent=default.parent,
        is_self=default.is_self,
        search_path=ret.path,
        provider=ret.provider,
    )

    ret = self._embed_result_config(ret, default.package)

    return ret, trace
def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictConfig, Dict]] = None) -> Optional[Path]:
    """
    exp_manager is a helper function used to manage folders for experiments. It follows the pytorch lightning
    paradigm of exp_dir/model_or_experiment_name/version. If the lightning trainer has a logger, exp_manager will get
    exp_dir, name, and version from the logger. Otherwise it will use the exp_dir and name arguments to create the
    logging directory. exp_manager also allows for explicit folder creation via explicit_log_dir.

    The version will be a datetime string or an integer. Note, exp_manager does not handle versioning on slurm
    multi-node runs. Datetime versioning can be disabled by setting use_datetime_version to False.

    It optionally creates TensorBoardLogger, WandBLogger, ModelCheckpoint objects from pytorch lightning. It copies
    sys.argv, and git information if available, to the logging directory. It creates a log file for each process to
    log their output into.

    exp_manager additionally has a resume feature which can be used to continue training from the constructed
    log_dir.

    Args:
        trainer (pytorch_lightning.Trainer): The lightning trainer.
        cfg (DictConfig, dict): Can have the following keys:

            - explicit_log_dir (str, Path): Can be used to override exp_dir/name/version folder creation. Defaults to
              None, which will use exp_dir, name, and version to construct the logging directory.
            - exp_dir (str, Path): The base directory to create the logging directory. Defaults to None, which logs
              to ./nemo_experiments.
            - name (str): The name of the experiment. Defaults to None which turns into "default" via
              name = name or "default".
            - version (str): The version of the experiment. Defaults to None which uses either a datetime string or
              lightning's TensorboardLogger system of using version_{int}.
            - use_datetime_version (bool): Whether to use a datetime string for version. Defaults to True.
            - resume_if_exists (bool): Whether this experiment is resuming from a previous run. If True, it sets
              trainer.resume_from_checkpoint so that the trainer should auto-resume. exp_manager will move files
              under log_dir to log_dir/run_{int}. Defaults to False.
            - resume_past_end (bool): exp_manager errors out if resume_if_exists is True and a checkpoint matching
              *end.ckpt exists, indicating that a previous training run fully completed. This behaviour can be
              disabled, in which case the *end.ckpt will be loaded, by setting resume_past_end to True. Defaults
              to False.
            - resume_ignore_no_checkpoint (bool): exp_manager errors out if resume_if_exists is True and no
              checkpoint could be found. This behaviour can be disabled, in which case exp_manager will print a
              message and continue without restoring, by setting resume_ignore_no_checkpoint to True. Defaults to
              False.
            - create_tensorboard_logger (bool): Whether to create a tensorboard logger and attach it to the pytorch
              lightning trainer. Defaults to True.
            - summary_writer_kwargs (dict): A dictionary of kwargs that can be passed to lightning's
              TensorboardLogger class. Note that log_dir is passed by exp_manager and cannot exist in this dict.
              Defaults to None.
            - create_wandb_logger (bool): Whether to create a Weights and Biases logger and attach it to the pytorch
              lightning trainer. Defaults to False.
            - wandb_logger_kwargs (dict): A dictionary of kwargs that can be passed to lightning's WandBLogger
              class. Note that name and project are required parameters if create_wandb_logger is True.
              Defaults to None.
            - create_checkpoint_callback (bool): Whether to create a ModelCheckpoint callback and attach it to the
              pytorch lightning trainer. The ModelCheckpoint saves the top 3 models with the best "val_loss", the
              most recent checkpoint under *last.ckpt, and the final checkpoint after training completes under
              *end.ckpt. Defaults to True.
            - checkpoint_callback_params: Parameters for the ModelCheckpoint callback. They are passed to
              configure_checkpointing and stored on AppState.
            - files_to_copy (list): A list of files to copy to the experiment logging directory. Defaults to None
              which copies no files.

    Returns:
        log_dir (Path): The final logging directory where logging files are saved. Usually the concatenation of
        exp_dir, name, and version. None is returned, and nothing is set up, when cfg is None or when
        trainer.fast_dev_run is set.

    Raises:
        ValueError: if cfg is neither a dict nor a DictConfig.
    """
    # Add rank information to logger
    # Note: trainer.global_rank and trainer.is_global_zero are not set until trainer.fit, so have to hack around it
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    global_rank = trainer.node_rank * trainer.num_gpus + local_rank
    logging.rank = global_rank

    if cfg is None:
        logging.error("exp_manager did not receive a cfg argument. It will be disabled.")
        return None
    if trainer.fast_dev_run:
        logging.info("Trainer was called with fast_dev_run. exp_manager will return without any functionality.")
        return None

    # Ensure passed cfg is compliant with ExpManagerConfig
    schema = OmegaConf.structured(ExpManagerConfig)
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    elif not isinstance(cfg, DictConfig):
        raise ValueError(f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig")
    # Resolve interpolations before merging so the schema-merged config holds concrete values.
    cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True))
    cfg = OmegaConf.merge(schema, cfg)

    error_checks(trainer, cfg)  # Ensures that trainer options are compliant with NeMo and exp_manager arguments

    log_dir, exp_dir, name, version = get_log_dir(
        trainer=trainer,
        exp_dir=cfg.exp_dir,
        name=cfg.name,
        version=cfg.version,
        explicit_log_dir=cfg.explicit_log_dir,
        use_datetime_version=cfg.use_datetime_version,
    )
    if cfg.resume_if_exists:
        check_resume(trainer, log_dir, cfg.resume_past_end, cfg.resume_ignore_no_checkpoint)

    checkpoint_name = name
    # If name returned from get_log_dir is "", use cfg.name for checkpointing
    if checkpoint_name is None or checkpoint_name == '':
        checkpoint_name = cfg.name or "default"
    cfg.name = name  # Used for configure_loggers so that the log_dir is properly set even if name is ""
    cfg.version = version

    # update app_state with log_dir, exp_dir, etc
    app_state = AppState()
    app_state.log_dir = log_dir
    app_state.exp_dir = exp_dir
    app_state.name = name
    app_state.version = version
    app_state.checkpoint_name = checkpoint_name
    app_state.create_checkpoint_callback = cfg.create_checkpoint_callback
    app_state.checkpoint_callback_params = cfg.checkpoint_callback_params

    # Create the logging directory if it does not exist
    # Cannot limit creation to global zero as all ranks write to own log file
    os.makedirs(log_dir, exist_ok=True)
    logging.info(f'Experiments will be logged at {log_dir}')
    trainer._default_root_dir = log_dir

    # Handle Loggers by creating file and handle DEBUG statements
    log_file = log_dir / f'nemo_log_globalrank-{global_rank}_localrank-{local_rank}.txt'
    logging.add_file_handler(log_file)

    # For some reason, LearningRateLogger requires trainer to have a logger. Safer to create logger on all ranks
    # not just global rank 0.
    if cfg.create_tensorboard_logger or cfg.create_wandb_logger:
        configure_loggers(
            trainer,
            exp_dir,
            cfg.name,
            cfg.version,
            cfg.create_tensorboard_logger,
            cfg.summary_writer_kwargs,
            cfg.create_wandb_logger,
            cfg.wandb_logger_kwargs,
        )

    if cfg.create_checkpoint_callback:
        configure_checkpointing(trainer, log_dir, checkpoint_name, cfg.checkpoint_callback_params)

    if is_global_rank_zero():
        # Copy files_to_copy into the log directory
        if cfg.files_to_copy:
            for _file in cfg.files_to_copy:
                copy(Path(_file), log_dir)

        # Create files for cmd args and git info. An explicit encoding keeps non-ASCII argv / diff
        # content from failing on platforms whose locale encoding is not UTF-8.
        with open(log_dir / 'cmd-args.log', 'w', encoding='utf-8') as _file:
            _file.write(" ".join(sys.argv))

        # Try to get git hash
        git_repo, git_hash = get_git_hash()
        if git_repo:
            with open(log_dir / 'git-info.log', 'w', encoding='utf-8') as _file:
                _file.write(f'commit hash: {git_hash}')
                _file.write(get_git_diff())

        # Add err_file logging to global_rank zero
        logging.add_err_file_handler(log_dir / 'nemo_error_log.txt')

        # Add lightning file logging to global_rank zero
        add_filehandlers_to_pl_logger(log_dir / 'lightning_logs.txt', log_dir / 'nemo_error_log.txt')

    return log_dir
), pytest.param( Expected( create=lambda: create_readonly({"foo1": "bar"}), op=lambda cfg: cfg.merge_with({"foo2": "bar"}), exception_type=ReadonlyConfigError, key="foo2", msg="Cannot change read-only config container", ), id="dict,readonly:merge_with", ), pytest.param( Expected( create=lambda: OmegaConf.structured(ConcretePlugin), op=lambda cfg: OmegaConf.merge(cfg, {"params": { "foo": "bar" }}), exception_type=ValidationError, msg="Value 'bar' could not be converted to Integer", key="foo", full_key="params.foo", object_type=ConcretePlugin.FoobarParams, child_node=lambda cfg: cfg.params.foo, parent_node=lambda cfg: cfg.params, ), id="structured:merge,invalid_field_type", ), pytest.param( Expected( create=lambda: OmegaConf.structured(ConcretePlugin), op=lambda cfg: OmegaConf.merge(cfg, {"params": {
def load_configuration(
    self,
    config_name: Optional[str],
    overrides: List[str],
    strict: Optional[bool] = None,
) -> DictConfig:
    """Compose the final config for ``config_name`` from Hydra's config, the job config and ``overrides``.

    Overrides consumed by the defaults lists are applied there; the remaining overrides are
    merged as a dotlist after struct mode has been set. The applied overrides are recorded
    under ``hydra.overrides.task`` / ``hydra.overrides.hydra``, and the ``hydra.job`` node is
    filled in (name, override_dirname, config_name, env_set).

    Args:
        config_name: name of the primary config, or None for no primary config.
        overrides: command-line overrides; a deep copy is used, the caller's list is not modified.
        strict: struct flag for the composed config; None means ``self.default_strict``.

    Raises:
        MissingConfigException: if ``config_name`` is not found in the search path.
        HydraException: if merging the remaining overrides fails.
        KeyError: if a key listed in ``hydra.job.env_copy`` is not set in ``os.environ``.
    """
    assert config_name is None or isinstance(config_name, str)
    assert strict is None or isinstance(strict, bool)
    assert isinstance(overrides, list)
    if strict is None:
        strict = self.default_strict

    # NOTE(review): redundant — the assert above already requires `overrides` to be a list.
    assert overrides is None or isinstance(overrides, list)

    overrides = copy.deepcopy(overrides) or []

    if config_name is not None and not self.exists_in_search_path(config_name):
        # TODO: handle schema as a special case
        descs = [
            f"\t{src.path} (from {src.provider})"
            for src in self.repository.get_sources()
        ]
        lines = "\n".join(descs)
        raise MissingConfigException(
            missing_cfg_file=config_name,
            message=f"Cannot find primary config file: {config_name}\nSearch path:\n{lines}",
        )

    # Load hydra config
    hydra_cfg, _load_trace = self._create_cfg(cfg_filename="hydra_config")

    # Load job config
    job_cfg, job_cfg_load_trace = self._create_cfg(
        cfg_filename=config_name, record_load=False
    )

    job_defaults = ConfigLoaderImpl._get_defaults(job_cfg)
    defaults = ConfigLoaderImpl._get_defaults(hydra_cfg)

    # A structured (non-dict) job config type is propagated onto the hydra config.
    job_cfg_type = OmegaConf.get_type(job_cfg)
    if job_cfg_type is not None and not issubclass(job_cfg_type, dict):
        hydra_cfg._promote(job_cfg_type)
        # this is breaking encapsulation a bit. can potentially be implemented in OmegaConf
        hydra_cfg._metadata.ref_type = job_cfg._metadata.ref_type

        # if defaults are re-introduced by the promotion, remove it.
        if "defaults" in hydra_cfg:
            with open_dict(hydra_cfg):
                del hydra_cfg["defaults"]

    if config_name is not None:
        defaults.append("__SELF__")
    # Boundary between hydra's own defaults (incl. __SELF__) and the job defaults;
    # presumably _merge_default_lists places the job defaults after this index — confirm.
    split_at = len(defaults)

    ConfigLoaderImpl._merge_default_lists(defaults, job_defaults)
    consumed = self._apply_defaults_overrides(overrides, defaults)

    consumed_free_job_defaults = self._apply_free_defaults(defaults, overrides)

    ConfigLoaderImpl._validate_defaults(defaults)

    # Load and defaults and merge them into cfg
    cfg = self._merge_defaults(hydra_cfg, job_cfg, job_cfg_load_trace, defaults, split_at)
    OmegaConf.set_struct(cfg.hydra, True)
    OmegaConf.set_struct(cfg, strict)

    # Merge all command line overrides after enabling strict flag
    all_consumed = consumed + consumed_free_job_defaults
    remaining_overrides = [x for x in overrides if x not in all_consumed]
    try:
        merged = OmegaConf.merge(cfg, OmegaConf.from_dotlist(remaining_overrides))
        assert isinstance(merged, DictConfig)
        cfg = merged
    except OmegaConfBaseException as ex:
        raise HydraException("Error merging overrides") from ex

    remaining = consumed + consumed_free_job_defaults + remaining_overrides

    def is_hydra(x: str) -> bool:
        # Overrides that target Hydra itself, either by key path or by config group.
        return x.startswith("hydra.") or x.startswith("hydra/")

    cfg.hydra.overrides.task = [x for x in remaining if not is_hydra(x)]
    cfg.hydra.overrides.hydra = [x for x in remaining if is_hydra(x)]

    # open_dict lifts struct mode so new keys can be written under hydra.job.
    with open_dict(cfg.hydra.job):
        if "name" not in cfg.hydra.job:
            cfg.hydra.job.name = JobRuntime().get("name")
        cfg.hydra.job.override_dirname = get_overrides_dirname(
            input_list=cfg.hydra.overrides.task,
            kv_sep=cfg.hydra.job.config.override_dirname.kv_sep,
            item_sep=cfg.hydra.job.config.override_dirname.item_sep,
            exclude_keys=cfg.hydra.job.config.override_dirname.exclude_keys,
        )
        cfg.hydra.job.config_name = config_name

        # Copy the requested environment variables into env_set (KeyError if unset).
        for key in cfg.hydra.job.env_copy:
            cfg.hydra.job.env_set[key] = os.environ[key]

    return cfg
def test_merge_missing_object_onto_typed_dictconfig(self, module: Any) -> None:
    """Merging "???" as a new value of a typed dict keeps that entry missing."""
    typed = OmegaConf.structured(module.DictOfObjects)
    result = OmegaConf.merge(typed, {"users": {"bob": "???"}})
    assert isinstance(result, DictConfig)
    assert OmegaConf.is_missing(result.users, "bob")
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
    """Build the end-to-end FastPitch + HiFi-GAN model from ``cfg``.

    ``cfg`` may also be a plain dict; it is converted with ``OmegaConf.create``.

    Raises:
        ValueError: if ``cfg`` is neither a dict nor a DictConfig after the base initializer runs.
    """
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)

    # NOTE(review): the text parser is built before the base initializer runs — presumably
    # because super().__init__ (e.g. dataset setup) needs self._parser; confirm.
    self._parser = parsers.make_parser(
        labels=cfg.labels,
        name='en',
        unk_id=-1,
        blank_id=-1,
        do_normalize=True,
        abbreviation_version="fastpitch",
        make_table=False,
    )

    super().__init__(cfg=cfg, trainer=trainer)

    schema = OmegaConf.structured(FastPitchHifiGanE2EConfig)
    # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    elif not isinstance(cfg, DictConfig):
        raise ValueError(f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig")
    # Ensure passed cfg is compliant with schema: the merged result is discarded on purpose;
    # the call only serves to raise if cfg conflicts with the schema.
    OmegaConf.merge(cfg, schema)

    self.preprocessor = instantiate(cfg.preprocessor)
    # Second preprocessor instance (highfreq=None, use_grads=True), used as the mel-spectrogram function.
    self.melspec_fn = instantiate(cfg.preprocessor, highfreq=None, use_grads=True)

    self.encoder = instantiate(cfg.input_fft)
    self.duration_predictor = instantiate(cfg.duration_predictor)
    self.pitch_predictor = instantiate(cfg.pitch_predictor)

    self.generator = instantiate(cfg.generator)
    self.multiperioddisc = MultiPeriodDiscriminator()
    self.multiscaledisc = MultiScaleDiscriminator()

    self.mel_val_loss = L1MelLoss()
    self.feat_matching_loss = FeatureMatchingLoss()
    self.disc_loss = DiscriminatorLoss()
    self.gen_loss = GeneratorLoss()

    self.max_token_duration = cfg.max_token_duration

    # padding = (k - 1) / 2 keeps the output length equal to the input length for odd kernel
    # sizes (Conv1d's default stride is 1).
    self.pitch_emb = torch.nn.Conv1d(
        1,
        cfg.symbols_embedding_dim,
        kernel_size=cfg.pitch_embedding_kernel_size,
        padding=int((cfg.pitch_embedding_kernel_size - 1) / 2),
    )

    # Store values precomputed from training data for convenience
    self.register_buffer('pitch_mean', torch.zeros(1))
    self.register_buffer('pitch_std', torch.zeros(1))

    self.loss = BaseFastPitchLoss()

    self.mel_loss_coeff = cfg.mel_loss_coeff

    self.log_train_images = False
    self.logged_real_samples = False
    self._tb_logger = None
    self.hann_window = None
    self.splice_length = cfg.splice_length
    self.sample_rate = cfg.sample_rate
    self.hop_size = cfg.hop_size
def test_merge_into_missing_sc(self, module: Any) -> None:
    """Merging "???" onto a structured-config field keeps its default Plugin value."""
    holder = OmegaConf.structured(module.PluginHolder)
    result = OmegaConf.merge(holder, {"plugin": "???"})
    assert result.plugin == module.Plugin()
def test_3way_dict_merge() -> None:
    """When merging three configs, later ones win on conflicting keys."""
    first = OmegaConf.create("{a: 1, b: 2}")
    second = OmegaConf.create("{b: 3}")
    third = OmegaConf.create("{a: 2, c: 3}")
    merged = OmegaConf.merge(first, second, third)
    assert {"a": 2, "b": 3, "c": 3} == merged
def test_plugin_merge(self, module: Any) -> None:
    """Merging a ConcretePlugin onto a Plugin yields the concrete values and type."""
    base = OmegaConf.structured(module.Plugin)
    concrete = OmegaConf.structured(module.ConcretePlugin)
    result = OmegaConf.merge(base, concrete)
    assert result == concrete
    assert OmegaConf.get_type(result) == module.ConcretePlugin
def test_with_readonly(c1: Any, c2: Any) -> None:
    """Merging onto a read-only config produces a read-only result."""
    base = OmegaConf.create(c1)
    OmegaConf.set_readonly(base, True)
    result = OmegaConf.merge(base, c2)
    assert OmegaConf.is_readonly(result)
def test_merge_of_non_subclass_1(self, module: Any) -> None:
    """A ``plugin`` node holding Plugin rejects a merge with a non-subclass (FaultyPlugin)."""
    expected = OmegaConf.create({"plugin": module.Plugin})
    faulty = OmegaConf.create({"plugin": module.FaultyPlugin})
    with raises(ValidationError):
        OmegaConf.merge(expected, faulty)
def test_merged_with_nons_subclass(self, class_type: str) -> None:
    """Merging a FaultyPlugin structured config onto a Plugin one fails validation."""
    mod: Any = import_module(class_type)
    base = OmegaConf.structured(mod.Plugin)
    faulty = OmegaConf.structured(mod.FaultyPlugin)
    with pytest.raises(ValidationError):
        OmegaConf.merge(base, faulty)
def test_merge_error_new_attribute(self, module: Any) -> None:
    """Merging a key the struct does not declare (params.bar) must fail."""
    base = OmegaConf.structured(module.ConcretePlugin)
    extra = OmegaConf.create({"params": {"bar": 10}})
    with raises(ConfigKeyError):
        OmegaConf.merge(base, extra)
def test_merge_user_list_with_wrong_key(self, class_type: str) -> None:
    """A list element carrying a key that ``User`` does not declare is rejected."""
    mod: Any = import_module(class_type)
    users = OmegaConf.structured(mod.UserList)
    with pytest.raises(ConfigKeyError):
        OmegaConf.merge(users, {"list": [{"foo": "var"}]})
def test_merge_error_override_bad_type(self, module: Any) -> None:
    """Merging a value of the wrong type into a typed field must raise ValidationError."""
    cfg = OmegaConf.structured(module.ConcretePlugin)

    # raise if a value that cannot be converted to the field's declared type
    # (here "zonk" into the integer field params.foo) is merged into the config
    with raises(ValidationError):
        OmegaConf.merge(cfg, {"params": {"foo": "zonk"}})
def test_merge_dict_with_wrong_type(self, class_type: str) -> None:
    """A plain string value cannot be merged into ``UserDict``'s User-typed values."""
    mod: Any = import_module(class_type)
    users = OmegaConf.structured(mod.UserDict)
    with pytest.raises(ValidationError):
        OmegaConf.merge(users, {"dict": {"foo": "var"}})
def _load_config_impl(
    self, input_file: str, record_load: bool = True
) -> Tuple[Optional[DictConfig], Optional[LoadTrace]]:
    """
    Load ``input_file`` from the repository, merged onto its ConfigStore schema when one exists.

    :param input_file: path of the config to load
    :param record_load: if True, the resulting LoadTrace is also appended to ``self.all_config_checked``
    :return: a ``(config, trace)`` tuple; config is None if the file was not found
    :raises ValueError: if the loaded config is not a DictConfig
    """

    def record_loading(
        name: str,
        path: Optional[str],
        provider: Optional[str],
        schema_provider: Optional[str],
    ) -> Optional[LoadTrace]:
        trace = LoadTrace(
            filename=name,
            path=path,
            provider=provider,
            schema_provider=schema_provider,
        )

        if record_load:
            self.all_config_checked.append(trace)

        return trace

    ret = self.repository.load_config(config_path=input_file)
    if ret is None:
        return (
            None,
            record_loading(name=input_file, path=None, provider=None, schema_provider=None),
        )

    if not isinstance(ret.config, DictConfig):
        # Report the type of the loaded config itself, not of its ConfigResult wrapper.
        raise ValueError(
            f"Config {input_file} must be a Dictionary, got {type(ret.config).__name__}"
        )

    if not ret.is_schema_source:
        # Only the schema lookup can legitimately fail with "not found"; keep the try narrow.
        try:
            schema = ConfigStore.instance().load(
                config_path=ConfigSource._normalize_file_name(filename=input_file)
            )
        except ConfigLoadError:
            # schema not found, ignore
            pass
        else:
            merged = OmegaConf.merge(schema.node, ret.config)
            assert isinstance(merged, DictConfig)
            return (
                merged,
                record_loading(
                    name=input_file,
                    path=ret.path,
                    provider=ret.provider,
                    schema_provider=schema.provider,
                ),
            )

    return (
        ret.config,
        record_loading(
            name=input_file,
            path=ret.path,
            provider=ret.provider,
            schema_provider=None,
        ),
    )
def initialize(self, config_path: str, force: bool = False) -> OmegaConf:
    """
    Load, merge and (optionally) persist the configuration.

    The system config and the user config at ``config_path`` are merged, the ``defaults`` node is
    dropped, and dotlist overrides from ``sys.argv[1:]`` are merged on top. Unless chdir is
    disabled, the working directory is changed to ``flyconfig.run.dir``. On LOCAL_RANK 0 with
    logging enabled, logging is configured and the config directory, the system config and the
    user config are saved under ``flyconfig.output_subdir``.

    Args:
        config_path: a file or dir
        force: re-initialize even if this instance is already initialized
    Returns:
        user_config: only return the user config (the merged config without its ``flyconfig`` node)
    Raises:
        ValueError: if already initialized and ``force`` is False.
    """
    if self.initialized and not force:
        raise ValueError("FlyConfig is already initialized!")

    config_path = check_config_path(config_path)

    init_omegaconf()
    system_config = load_system_config()
    user_config = load_user_config(config_path)
    config = OmegaConf.merge(system_config, user_config)

    # get current working dir
    config.flyconfig.runtime.cwd = os.getcwd()

    # change working dir
    if not self.disable_chdir:
        working_dir_path = config.flyconfig.run.dir
        os.makedirs(working_dir_path, exist_ok=True)
        os.chdir(working_dir_path)

    # Only the LOCAL_RANK 0 process (with logging enabled) configures logging and writes files.
    is_logging_rank = int(os.environ.get("LOCAL_RANK", 0)) == 0 and not self.disable_logging

    # configure logging
    if is_logging_rank:
        logging.config.dictConfig(OmegaConf.to_container(config.flyconfig.logging))
        logger.info("FlyConfig Initialized")
        if not self.disable_chdir:
            logger.info(f"Working directory is changed to {working_dir_path}")

    # clean defaults
    del config["defaults"]

    # overrides: merged last so command-line values take precedence
    overrides = get_overrides_from_argv(sys.argv[1:])
    overrides_config = OmegaConf.from_dotlist(overrides)
    config = OmegaConf.merge(config, overrides_config)

    # get system config
    self.system_config = OmegaConf.create({"flyconfig": OmegaConf.to_container(config.flyconfig)})

    # get user config
    self.user_config = copy.deepcopy(config)
    del self.user_config["flyconfig"]

    # save config
    if is_logging_rank:
        # Read once so every saved file lands in the same directory.
        output_subdir = self.system_config.flyconfig.output_subdir
        os.makedirs(output_subdir, exist_ok=True)

        # save the entire config directory
        cwd = self.system_config.flyconfig.runtime.cwd
        dirpath = os.path.join(cwd, os.path.dirname(config_path))
        # NOTE(review): copytree raises if <output_subdir>/config already exists (e.g. a
        # re-initialization with force=True writing to the same output dir) — confirm intent.
        shutil.copytree(dirpath, os.path.join(output_subdir, "config"))

        # save system config
        _save_config(filepath=os.path.join(output_subdir, "flyconfig.yml"), config=self.system_config)
        # save user config
        _save_config(filepath=os.path.join(output_subdir, "config.yml"), config=self.user_config)

        # DictConfig.pretty() was deprecated in OmegaConf 2.0 and removed in 2.1 (where attribute
        # access is treated as a config key lookup); OmegaConf.to_yaml is its replacement.
        logger.info("\n\nConfiguration:\n" + OmegaConf.to_yaml(self.user_config))

    self.initialized = True

    return self.user_config
from hydra.utils import get_class, instantiate
from omegaconf import OmegaConf
from torch import Tensor

# Resolve the generated TensorDataset config class and use it as a structured schema.
full_class = "gen.configen_tests.utils.data.dataset.TensorDatasetConf"
schema = OmegaConf.structured(get_class(full_class))

# Merge an empty user config onto the schema, leaving its defaults in place.
cfg = {}
cfg = OmegaConf.merge(schema, cfg)

# NOTE(review): `(Tensor([1]))` is a bare Tensor, not a 1-tuple; if a tuple of tensors
# was intended this should read `(Tensor([1]),)` — confirm against TensorDatasetConf.
obj = instantiate(cfg, tensors=(Tensor([1])))
print(obj)