def test_get_config_header(cfg_text: str, expected: Any, sep: str) -> None:
    """Check header parsing of ``cfg_text`` once ``{sep}`` is substituted.

    If ``expected`` is a dict, it is the exact header that
    ``ConfigSource._get_header_dict`` must return. Otherwise ``expected``
    is entered as a context manager around the call. Presumably this is a
    ``pytest.raises(...)`` from the parametrization; confirm there.
    """
    formatted = cfg_text.format(sep=sep)
    if not isinstance(expected, dict):
        # Error case: the context manager decides whether the call passes.
        with expected:
            ConfigSource._get_header_dict(formatted)
        return
    assert ConfigSource._get_header_dict(formatted) == expected
def load_config(
    self,
    config_path: str,
    is_primary_config: bool,
    package_override: Optional[str] = None,
) -> ConfigResult:
    """Load ``config_path`` as a resource of the package named by ``self.path``.

    The first 512 characters are parsed as the header. The header is then
    updated via ``self._update_package_in_header``, and the whole resource
    is loaded with ``OmegaConf.load``.

    Args:
        config_path: config name, normalized via ``self._normalize_file_name``.
        is_primary_config: forwarded to ``self._update_package_in_header``.
        package_override: forwarded to ``self._update_package_in_header``.

    Returns:
        A ``ConfigResult`` with these fields:

        - the config passed through ``self._embed_config`` with
          ``header["package"]``;
        - the parsed header;
        - the defaults list from ``self._extract_defaults_list``.

    Raises:
        ConfigLoadError: if the resource does not exist.
    """
    normalized_config_path = self._normalize_file_name(config_path)
    res = importlib_resources.files(self.path).joinpath(normalized_config_path)
    if not res.exists():
        raise ConfigLoadError(f"Config not found : {normalized_config_path}")
    # ``res`` is a Traversable, which need not be a filesystem path (e.g. a
    # resource inside a zip archive), so open it through its own ``open``
    # method rather than the builtin ``open``.
    with res.open(encoding="utf-8") as f:
        # Only the start of the file is needed to parse the header.
        header_text = f.read(512)
        header = ConfigSource._get_header_dict(header_text)
        self._update_package_in_header(
            header=header,
            normalized_config_path=normalized_config_path,
            is_primary_config=is_primary_config,
            package_override=package_override,
        )
        # Rewind so the whole file is parsed as the config body.
        f.seek(0)
        cfg = OmegaConf.load(f)
        defaults_list = self._extract_defaults_list(
            config_path=config_path, cfg=cfg
        )
        return ConfigResult(
            config=self._embed_config(cfg, header["package"]),
            path=f"{self.scheme()}://{self.path}",
            provider=self.provider,
            header=header,
            defaults_list=defaults_list,
        )
def load_config(
    self,
    config_path: str,
    is_primary_config: bool,
    package_override: Optional[str] = None,
) -> ConfigResult:
    """Load the config file ``config_path`` from under the directory ``self.path``.

    The first 512 characters are parsed as the header. The header is then
    updated via ``self._update_package_in_header``, and the whole file is
    loaded with ``OmegaConf.load``.

    Args:
        config_path: config name, normalized via ``self._normalize_file_name``.
        is_primary_config: forwarded to ``self._update_package_in_header``.
        package_override: forwarded to ``self._update_package_in_header``.

    Returns:
        A ``ConfigResult`` with these fields:

        - the config passed through ``self._embed_config`` with
          ``header["package"]``;
        - the parsed header.

    Raises:
        ConfigLoadError: if the resolved file path does not exist.
    """
    normalized_config_path = self._normalize_file_name(config_path)
    full_path = os.path.realpath(os.path.join(self.path, normalized_config_path))
    if not os.path.exists(full_path):
        raise ConfigLoadError(f"Config not found : {full_path}")
    # Decode as UTF-8 explicitly. Without an encoding, ``open`` uses the
    # locale's preferred encoding, which varies by platform.
    with open(full_path, encoding="utf-8") as f:
        # Only the start of the file is needed to parse the header.
        header_text = f.read(512)
        header = ConfigSource._get_header_dict(header_text)
        self._update_package_in_header(
            header=header,
            normalized_config_path=normalized_config_path,
            is_primary_config=is_primary_config,
            package_override=package_override,
        )
        # Rewind so the whole file is parsed as the config body.
        f.seek(0)
        cfg = OmegaConf.load(f)
        return ConfigResult(
            config=self._embed_config(cfg, header["package"]),
            path=f"{self.scheme()}://{self.path}",
            provider=self.provider,
            header=header,
        )
def load_config(
    self,
    config_path: str,
    is_primary_config: bool,
    package_override: Optional[str] = None,
) -> ConfigResult:
    """Load ``config_path`` as a package resource opened with ``resource_stream``.

    The joined path ``self.path`` + normalized config path is split into a
    module name and a resource name. The first 512 bytes are decoded as
    UTF-8 and parsed as the header. The header is then updated via
    ``self._update_package_in_header``, and the whole stream is loaded with
    ``OmegaConf.load``.

    Args:
        config_path: config name, normalized via ``self._normalize_file_name``.
        is_primary_config: forwarded to ``self._update_package_in_header``.
        package_override: forwarded to ``self._update_package_in_header``.

    Returns:
        A ``ConfigResult`` with these fields:

        - the config passed through ``self._embed_config`` with
          ``header["package"]``;
        - the parsed header.

    Raises:
        ConfigLoadError: if opening the resource raises ``FileNotFoundError``.
    """
    import codecs

    normalized_config_path = self._normalize_file_name(filename=config_path)
    module_name, resource_name = PackageConfigSource._split_module_and_resource(
        self.concat(self.path, normalized_config_path)
    )
    # Only acquiring the stream is expected to signal a missing resource;
    # keep the ``try`` limited to that call.
    try:
        stream = resource_stream(module_name, resource_name)
    except FileNotFoundError as e:
        raise ConfigLoadError(
            f"Config not found: module={module_name}, resource_name={resource_name}"
        ) from e
    with stream:
        header_bytes = stream.read(512)
        # A fixed 512-byte prefix can end in the middle of a multi-byte UTF-8
        # character. An incremental decoder (final=False by default) holds
        # back that incomplete tail instead of raising UnicodeDecodeError.
        header_text = codecs.getincrementaldecoder("utf-8")().decode(header_bytes)
        header = ConfigSource._get_header_dict(header_text)
        self._update_package_in_header(
            header=header,
            normalized_config_path=normalized_config_path,
            is_primary_config=is_primary_config,
            package_override=package_override,
        )
        # Rewind so the whole resource is parsed as the config body.
        stream.seek(0)
        cfg = OmegaConf.load(stream)
        return ConfigResult(
            config=self._embed_config(cfg, header["package"]),
            path=f"{self.scheme()}://{self.path}",
            provider=self.provider,
            header=header,
        )
def load_config(self, config_path: str) -> ConfigResult:
    """Read ``config_path`` from the resources of the package ``self.path``.

    The header is parsed from the first 512 characters. The full resource
    is then loaded with ``OmegaConf.load`` and returned unmodified in a
    ``ConfigResult``.

    Raises:
        ConfigLoadError: if the resource does not exist.
    """
    normalized_config_path = self._normalize_file_name(config_path)
    resource = resources.files(self.path).joinpath(normalized_config_path)  # type:ignore
    if not resource.exists():
        raise ConfigLoadError(f"Config not found : {normalized_config_path}")

    with resource.open(encoding="utf-8") as stream:
        parsed_header = ConfigSource._get_header_dict(stream.read(512))
        # Back to the start: the config body is the entire resource.
        stream.seek(0)
        loaded = OmegaConf.load(stream)

    return ConfigResult(
        config=loaded,
        path=f"{self.scheme()}://{self.path}",
        provider=self.provider,
        header=parsed_header,
    )
def load_config(self, config_path: str) -> ConfigResult:
    """Read the config file ``config_path`` from under the directory ``self.path``.

    The path is resolved with ``os.path.realpath``. The header is parsed
    from the first 512 characters, and the whole file is loaded with
    ``OmegaConf.load``.

    Raises:
        ConfigLoadError: if the resolved path does not exist.
    """
    normalized_config_path = self._normalize_file_name(config_path)
    joined = os.path.join(self.path, normalized_config_path)
    full_path = os.path.realpath(joined)
    if not os.path.exists(full_path):
        raise ConfigLoadError(f"Config not found : {full_path}")

    with open(full_path, encoding="utf-8") as config_file:
        parsed_header = ConfigSource._get_header_dict(config_file.read(512))
        # Back to the start: the config body is the entire file.
        config_file.seek(0)
        loaded = OmegaConf.load(config_file)

    return ConfigResult(
        config=loaded,
        path=f"{self.scheme()}://{self.path}",
        provider=self.provider,
        header=parsed_header,
    )
def _read_config(self, res: Any) -> ConfigResult:
    """Parse the config resource ``res`` into a ``ConfigResult``.

    The header is parsed from the first 512 characters. The whole resource
    is then loaded with ``OmegaConf.load`` and returned without
    modification.

    Args:
        res: an object with an ``open`` method. Presumably an
            ``importlib.resources`` Traversable, possibly a ``zipfile.Path``
            (confirm against callers).

    Returns:
        A ``ConfigResult`` holding the loaded config and the parsed header.
    """
    import io

    if sys.version_info[0:2] == (3, 8) and isinstance(res, zipfile.Path):
        # On Python 3.8, zipfile.Path.open() takes no encoding argument and
        # returns a binary stream. Wrap it so reads yield UTF-8 text, which
        # also avoids decoding a byte prefix that may split a character.
        f = io.TextIOWrapper(res.open(), encoding="utf-8")
    else:
        # From 3.9, zipfile.Path.open() opens in text mode and accepts an
        # encoding. Passing it avoids falling back to the locale encoding.
        f = res.open(encoding="utf-8")
    # ``f`` is bound before entering ``with``. A failing ``open`` therefore
    # surfaces its own error instead of an UnboundLocalError on close.
    with f:
        header_text = f.read(512)
        header = ConfigSource._get_header_dict(header_text)
        # Rewind so the whole resource is parsed as the config body.
        f.seek(0)
        cfg = OmegaConf.load(f)
    return ConfigResult(
        config=cfg,
        path=f"{self.scheme()}://{self.path}",
        provider=self.provider,
        header=header,
    )