def configure_log(
    log_config: DictConfig,
    verbose_config: Union[bool, str, Sequence[str]],
) -> None:
    """Configure Python logging, then lower selected loggers to DEBUG.

    Args:
        log_config: an OmegaConf logging config, converted to a plain container
            (with interpolations resolved) and handed to
            ``logging.config.dictConfig``. When ``None``, the root logger is set
            to INFO and a stdout handler with a timestamped format is attached.
        verbose_config: ``True`` sets the root logger to DEBUG; ``False`` leaves
            levels alone; a string names one logger to set to DEBUG; a list
            config names several.
    """
    assert isinstance(verbose_config, (bool, str)) or OmegaConf.is_list(verbose_config)

    if log_config is None:
        # No explicit logging config: fall back to INFO-level logging on stdout.
        root_logger = logging.getLogger()
        root_logger.setLevel(logging.INFO)
        stdout_handler = logging.StreamHandler(sys.stdout)
        stdout_handler.setFormatter(
            logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s] - %(message)s")
        )
        root_logger.addHandler(stdout_handler)
    else:
        container: Dict[str, Any] = OmegaConf.to_container(  # type: ignore
            log_config, resolve=True
        )
        logging.config.dictConfig(container)

    if isinstance(verbose_config, bool):
        if verbose_config:
            logging.getLogger().setLevel(logging.DEBUG)
        return

    # Normalise a single logger name into a one-element list config.
    if isinstance(verbose_config, str):
        debug_loggers = OmegaConf.create([verbose_config])
    elif OmegaConf.is_list(verbose_config):
        debug_loggers = verbose_config  # type: ignore
    else:
        assert False
    for logger_name in debug_loggers:
        logging.getLogger(logger_name).setLevel(logging.DEBUG)
def check_if_validation(cfg: DictConfig) -> bool:
    """Return True when any tunable hyper-parameter is given as a list config.

    Checks ``cfg.training.lr`` and ``cfg.training.regularization`` always, and
    ``cfg.loss.xi`` only when ``check_if_gev_loss(get_loss_class(cfg))`` is true.
    Presumably a list means several candidate values to pick between on
    validation data — TODO confirm against callers.
    """
    # Every lookup is performed eagerly (no short-circuit), so missing keys
    # still raise regardless of earlier results.
    list_valued = [
        OmegaConf.is_list(cfg.training.lr),
        OmegaConf.is_list(cfg.training.regularization),
    ]
    if check_if_gev_loss(get_loss_class(cfg)):
        list_valued.append(OmegaConf.is_list(cfg.loss.xi))
    return any(list_valued)
def _get_kwargs(
    config: Union[DictConfig, ListConfig],
    **kwargs: Any,
) -> Any:
    """Turn ``config`` (merged with ``kwargs``) into call arguments.

    For a ListConfig, returns a plain Python list in which nested configs are
    processed recursively and other elements are kept as-is.

    For a DictConfig, ``kwargs`` are merged INTO ``config`` in place (the
    caller's config is mutated), and a plain dict of its entries, as returned
    by ``config.items()``, is built. When ``_is_recursive(config, kwargs)`` is
    true, each value is further processed:
      * values for which ``_is_target`` holds are instantiated via
        ``hydra.utils.instantiate``;
      * dict values are rebuilt one level deep (targets instantiated, nested
        configs recursed into);
      * list values are rebuilt the same way;
      * remaining values that are OmegaConf "none" become ``None``.

    Returns:
        ``list`` for list configs, ``dict`` for dict configs.
    """
    from hydra.utils import instantiate

    assert OmegaConf.is_config(config)

    if OmegaConf.is_list(config):
        assert isinstance(config, ListConfig)
        return [
            _get_kwargs(x) if OmegaConf.is_config(x) else x for x in config
        ]

    assert OmegaConf.is_dict(
        config), "Input config is not an OmegaConf DictConfig"

    final_kwargs = {}

    # Decide recursion before merging so the caller's kwargs can override it
    # (exact rule lives in _is_recursive).
    recursive = _is_recursive(config, kwargs)
    # allow_objects lets arbitrary Python objects be passed through kwargs.
    overrides = OmegaConf.create(kwargs, flags={"allow_objects": True})
    # NOTE: mutates the caller's config.
    config.merge_with(overrides)

    for k, v in config.items():
        final_kwargs[k] = v

    if recursive:
        # Reassigning existing keys while iterating is safe: the dict size
        # does not change.
        for k, v in final_kwargs.items():
            if _is_target(v):
                final_kwargs[k] = instantiate(v)
            elif OmegaConf.is_dict(v) and not OmegaConf.is_none(v):
                d = OmegaConf.create({}, flags={"allow_objects": True})
                for key, value in v.items():
                    if _is_target(value):
                        d[key] = instantiate(value)
                    elif OmegaConf.is_config(value):
                        d[key] = _get_kwargs(value)
                    else:
                        d[key] = value
                final_kwargs[k] = d
            elif OmegaConf.is_list(v):
                lst = OmegaConf.create([], flags={"allow_objects": True})
                for x in v:
                    if _is_target(x):
                        lst.append(instantiate(x))
                    elif OmegaConf.is_config(x):
                        lst.append(_get_kwargs(x))
                    else:
                        lst.append(x)
                final_kwargs[k] = lst
            else:
                # Normalise OmegaConf "none" nodes to a real Python None.
                if OmegaConf.is_none(v):
                    v = None
                final_kwargs[k] = v

    return final_kwargs
def __init__(self, fn):
    """Work out which task files this job runs and store them in ``self.run_yamls``.

    * ``fn`` ending in ``.py`` is a python command: ``self.backend`` becomes
      ``"python"`` and ``fn`` is the only entry.
    * A job config (loaded with ``recursive_config``) whose ``base_dir`` is
      None is a single-file job: ``fn`` is the only entry.
    * Otherwise every entry of ``run_task`` is joined onto
      ``projects/<project_dir>``. A stage that is itself a list becomes a
      nested list of paths (those tasks may run in parallel). The configs
      produced by ``self._overwrite_task`` are then written out via
      ``self._save_configs``.

    NOTE(review): when ``run_task`` is None, ``self.run_yamls`` is never assigned.
    """
    if fn.endswith(".py"):
        self.backend = "python"
        self.run_yamls = [fn]
        return

    job_config = recursive_config(fn)
    if job_config.base_dir is None:
        self.run_yamls = [fn]
        return

    self.project_dir = os.path.join("projects", job_config.project_dir)
    self.run_dir = os.path.join("runs", job_config.project_dir)

    if job_config.run_task is not None:
        project_dir = self.project_dir

        def to_path(task_file):
            return os.path.join(project_dir, task_file)

        self.run_yamls = [
            [to_path(task_file) for task_file in stage]
            if OmegaConf.is_list(stage)
            else to_path(stage)
            for stage in job_config.run_task
        ]

    self._save_configs(self._overwrite_task(job_config))
def _recurse_on_config(self, config):
    """Recursively validate the dataset resources declared in ``config``.

    A non-empty list whose first element has a ``url`` key is treated as a
    list of downloadable resources. Each entry is turned into a
    ``DownloadableFile`` and its URL is checked with ``check_header``, making
    up to three attempts with a 2-second pause between failed attempts.
    Dict configs must declare ``version`` and ``resources`` together, and
    every child is visited recursively.
    """
    max_attempts = 3
    retry_delay_seconds = 2

    if OmegaConf.is_list(config) and len(config) > 0 and "url" in config[0]:
        # Found the urls, let's test them
        for item in config:
            # Constructing DownloadableFile first surfaces malformed entries.
            download = DownloadableFile(**item)
            for attempt in range(max_attempts):
                try:
                    check_header(download._url, from_google=download._from_google)
                    break
                except AssertionError:
                    if attempt == max_attempts - 1:
                        raise
                    # Attempts remain: pause retry_delay_seconds (2 s) before retrying.
                    time.sleep(retry_delay_seconds)
    elif OmegaConf.is_dict(config):
        # Both version and resources should be present
        if "version" in config:
            self.assertIn("resources", config)
        if "resources" in config:
            self.assertIn("version", config)

        # Let's continue recursing
        for item in config:
            self._recurse_on_config(config[item])
def __init__(self, config, *args, **kwargs):
    """Compose the video transforms described by ``config.transforms``.

    ``config.transforms`` is a single dict entry or a list of entries. Each
    entry is either a transform name (str) or a dict with ``type`` and an
    optional ``params``. Every entry is built with
    ``self.get_transform_object`` and the results are wrapped in
    ``img_transforms.Compose``. Requires ``pytorchvideo`` to be importable.
    """
    specs = config.transforms
    assert OmegaConf.is_dict(specs) or OmegaConf.is_list(specs)
    if OmegaConf.is_dict(specs):
        specs = [specs]

    pytorchvideo_spec = importlib.util.find_spec("pytorchvideo")
    assert pytorchvideo_spec is not None, (
        "Must have pytorchvideo installed to use VideoTransforms"
    )

    built = []
    for spec in specs:
        if not OmegaConf.is_dict(spec):
            assert isinstance(spec, str), (
                "Each transform should either be str or dict containing "
                "type and params"
            )
            built.append(self.get_transform_object(spec, OmegaConf.create([])))
            continue
        # ``spec.type`` raises a config error when the key is missing.
        built.append(
            self.get_transform_object(
                spec.type, spec.get("params", OmegaConf.create({}))
            )
        )
    self.transform = img_transforms.Compose(built)
def _recurse_on_config(self, config: DictConfig, callback: typing.Callable):
    """Walk a dataset config and hand every downloadable resource to ``callback``.

    A non-empty list whose first element has a ``url`` key is a resource list.
    Each entry (except the flickr30 image archive, whose source is currently
    down) becomes a ``DownloadableFile`` that is passed to ``callback``.
    Dict configs must contain ``version`` and ``resources`` together, and
    their children are visited recursively with the same callback.
    """
    if OmegaConf.is_list(config) and len(config) > 0 and "url" in config[0]:
        # Skip flickr30 until a mirror for its download source is found.
        testable = (
            entry
            for entry in config
            if getattr(entry, "file_name", "") != "flickr30_images.tar.gz"
        )
        for entry in testable:
            # Building DownloadableFile validates the entry before the
            # scenario-specific checks in the callback run.
            callback(DownloadableFile(**entry))
    elif OmegaConf.is_dict(config):
        for present, required in (("version", "resources"), ("resources", "version")):
            if present in config:
                self.assertIn(required, config)
        for key in config:
            self._recurse_on_config(config[key], callback=callback)
def _get_matches(config: Container, word: str) -> List[str]:
    """Return shell-completion candidates for ``word`` within ``config``.

    * ``word`` ending in ``.`` or ``=``: complete the node at that exact key.
      A nested config yields its child keys; a primitive yields its value,
      with booleans lower-cased and a MISSING value yielding an empty
      completion.
    * ``word`` containing a dot: select the node before the last dot and
      complete the remaining fragment inside it.
    * otherwise: keys (or list indices) of ``config`` that start with
      ``word``, suffixed with ``.`` for nested configs and ``=`` for leaves.
    """

    def str_rep(in_key: Any, in_value: Any) -> str:
        # Nested configs complete with "." (descend); leaves with "=" (assign).
        if OmegaConf.is_config(in_value):
            return f"{in_key}."
        else:
            return f"{in_key}="

    if config is None:
        return []
    elif OmegaConf.is_config(config):
        matches = []
        if word.endswith(".") or word.endswith("="):
            exact_key = word[0:-1]
            try:
                conf_node = OmegaConf.select(
                    config, exact_key, throw_on_missing=True
                )
            except MissingMandatoryValue:
                conf_node = ""
            if conf_node is not None:
                if OmegaConf.is_config(conf_node):
                    key_matches = CompletionPlugin._get_matches(conf_node, "")
                else:
                    # primitive
                    if isinstance(conf_node, bool):
                        conf_node = str(conf_node).lower()
                    key_matches = [conf_node]
            else:
                key_matches = []

            matches.extend([f"{word}{match}" for match in key_matches])
        else:
            last_dot = word.rfind(".")
            if last_dot != -1:
                base_key = word[0:last_dot]
                partial_key = word[last_dot + 1 :]
                conf_node = OmegaConf.select(config, base_key)
                key_matches = CompletionPlugin._get_matches(conf_node, partial_key)
                matches.extend([f"{base_key}.{match}" for match in key_matches])
            else:
                if isinstance(config, DictConfig):
                    for key, value in config.items_ex(resolve=False):
                        str_key = str(key)
                        if str_key.startswith(word):
                            matches.append(str_rep(key, value))
                elif OmegaConf.is_list(config):
                    assert isinstance(config, ListConfig)
                    for idx in range(len(config)):
                        # Filter on the prefix first so that MISSING elements
                        # whose index does not match are not suggested.
                        if not str(idx).startswith(word):
                            continue
                        try:
                            value = config[idx]
                        except MissingMandatoryValue:
                            value = ""
                        matches.append(str_rep(idx, value))
    else:
        assert False, f"Object is not an instance of config : {type(config)}"

    return matches
def _get_matches(config: Container, word: str) -> List[str]:
    """Return shell-completion candidates for ``word`` within ``config``.

    * ``word`` ending in ``.`` or ``=``: complete the node at that exact key.
      A nested config yields its child keys; a primitive yields its value,
      with booleans lower-cased and a MISSING value yielding an empty
      completion.
    * ``word`` containing a dot: select the node before the last dot and
      complete the remaining fragment inside it.
    * otherwise: keys (or list indices) of ``config`` that start with
      ``word``, suffixed with ``.`` for nested configs and ``=`` for leaves.
    """

    def str_rep(in_key: Union[str, int], in_value: Any) -> str:
        # Nested configs complete with "." (descend); leaves with "=" (assign).
        if OmegaConf.is_config(in_value):
            return "{}.".format(in_key)
        else:
            return "{}=".format(in_key)

    if config is None:
        return []
    elif OmegaConf.is_config(config):
        matches = []
        if word.endswith(".") or word.endswith("="):
            exact_key = word[0:-1]
            try:
                conf_node = config.select(exact_key)
            except MissingMandatoryValue:
                conf_node = ""
            if conf_node is not None:
                if OmegaConf.is_config(conf_node):
                    key_matches = CompletionPlugin._get_matches(conf_node, "")
                else:
                    # primitive
                    if isinstance(conf_node, bool):
                        conf_node = str(conf_node).lower()
                    key_matches = [conf_node]
            else:
                key_matches = []

            matches.extend(["{}{}".format(word, match) for match in key_matches])
        else:
            last_dot = word.rfind(".")
            if last_dot != -1:
                base_key = word[0:last_dot]
                partial_key = word[last_dot + 1:]
                conf_node = config.select(base_key)
                key_matches = CompletionPlugin._get_matches(conf_node, partial_key)
                matches.extend(
                    ["{}.{}".format(base_key, match) for match in key_matches]
                )
            else:
                if isinstance(config, DictConfig):
                    for key, value in config.items_ex(resolve=False):
                        # str() guards against non-string keys.
                        if str(key).startswith(word):
                            matches.append(str_rep(key, value))
                elif OmegaConf.is_list(config):
                    for idx in range(len(config)):
                        # Filter on the prefix before touching the value, and
                        # tolerate MISSING elements instead of aborting.
                        if not str(idx).startswith(word):
                            continue
                        try:
                            value = config[idx]
                        except MissingMandatoryValue:
                            value = ""
                        matches.append(str_rep(idx, value))
    else:
        assert False, "Object is not an instance of config : {}".format(
            type(config))

    return matches
def __init__(self, config, *args, **kwargs):
    """Build a ``transforms.Compose`` pipeline from ``config.transforms``.

    ``config.transforms`` is a single dict entry or a list of entries. Each
    entry is a transform name (str) or a dict with ``type`` and an optional
    ``params``. Names are looked up first in ``transforms``, then in
    torchaudio's transforms, then in the processor registry. Converted
    ``params`` that form a mapping are passed as keyword arguments; otherwise
    they are passed as positional arguments.
    """
    entries = config.transforms
    assert OmegaConf.is_dict(entries) or OmegaConf.is_list(entries)
    if OmegaConf.is_dict(entries):
        entries = [entries]

    def _lookup(name):
        # 1) The ``transforms`` module.
        found = getattr(transforms, name, None)
        if found is None:
            # 2) torchaudio transforms, set up lazily only when needed.
            from mmf.utils.env import setup_torchaudio

            setup_torchaudio()
            from torchaudio import transforms as torchaudio_transforms

            found = getattr(torchaudio_transforms, name, None)
        if found is None:
            # 3) Neither torchvision nor torchaudio has it: a custom transform
            #    may be registered as a processor.
            found = registry.get_processor_class(name)
        return found

    pipeline = []
    for entry in entries:
        if OmegaConf.is_dict(entry):
            # ``entry.type`` raises a config error when the key is missing.
            name = entry.type
            raw_params = entry.get("params", OmegaConf.create({}))
        else:
            assert isinstance(entry, str), (
                "Each transform should either be str or dict containing "
                + "type and params"
            )
            name = entry
            raw_params = OmegaConf.create([])

        transform_cls = _lookup(name)
        assert transform_cls is not None, (
            f"transform {name} is not present in torchvision, "
            + "torchaudio or processor registry"
        )

        # Unpack plain containers, not OmegaConf nodes:
        # https://github.com/omry/omegaconf/issues/248
        plain_params = OmegaConf.to_container(raw_params)
        if isinstance(plain_params, collections.abc.Mapping):
            pipeline.append(transform_cls(**plain_params))
        else:
            pipeline.append(transform_cls(*plain_params))

    self.transform = transforms.Compose(pipeline)
def _recur_log_omegaconf_param(self, key: str, val: object) -> None:
    """Flatten an OmegaConf value and log each leaf via ``self.log_param``.

    Dict children are addressed as ``key.child``, or just ``child`` when
    ``key`` is the empty string (the root). List elements are addressed as
    ``key[i]``. Leaves are logged as ``str(val)``.

    Args:
        key: dotted path of ``val`` so far; ``""`` at the root.
        val: a DictConfig, ListConfig, or leaf value.
    """
    if OmegaConf.is_dict(val):
        for k, v in val.items():
            child_key = k if key == "" else f"{key}.{k}"
            self._recur_log_omegaconf_param(child_key, v)
    elif OmegaConf.is_list(val):
        for i, v in enumerate(val):
            self._recur_log_omegaconf_param(f"{key}[{i}]", v)
    else:
        self.log_param(key, str(val))
def configure_optimizers(self):
    """Create the optimizer (and optional LR scheduler) from the stored config.

    ``self.optimizer_args`` is either:

    * a dict config: its entries are keyword arguments for the optimizer,
      which is built over ``self.parameters()``;
    * a list config of groups, each with a ``params`` entry.
      - ``params: default`` marks the optimizer-wide keyword arguments; if no
        such group exists, none are passed.
      - Any other value (a name or a list of names) is matched as a substring
        against the names from ``self.model.named_parameters()``, and the
        group's remaining keys become that parameter's options.
      - The first matching name wins; parameters matching no name get a group
        without extra options.

    Optimizer and scheduler classes are looked up by name on ``optimizers``.

    Returns:
        The optimizer if ``self.scheduler`` is None, else
        ``([optimizer], [scheduler])``.

    Raises:
        TypeError: if ``self.optimizer_args`` is neither a list nor a dict config.
    """
    if OmegaConf.is_list(self.optimizer_args):
        # Optimizer-wide kwargs from the ``params: default`` group, if any.
        default_kwargs = {}
        # Parameter-name pattern -> per-group kwargs.
        kwargs_dict = {}
        for group in self.optimizer_args:
            # Copy before popping: popping on the stored config itself would
            # remove 'params' and make any later call fail.
            kwargs = dict(group)
            param_names = kwargs.pop('params')
            if param_names == 'default':
                default_kwargs = kwargs
            else:
                if isinstance(param_names, str):
                    param_names = [param_names]
                for param in param_names:
                    kwargs_dict[param] = kwargs

        optimized_params = []
        for n, p in self.model.named_parameters():
            for param, kwargs in kwargs_dict.items():
                if param in n:
                    optimized_params.append({'params': p, **kwargs})
                    break
            else:
                # No pattern matched (or none were given): optimizer defaults.
                optimized_params.append({'params': p})
        optimizer = getattr(optimizers, self.optimizer)(optimized_params,
                                                        **default_kwargs)
    elif OmegaConf.is_dict(self.optimizer_args):
        optimizer = getattr(optimizers, self.optimizer)(self.parameters(),
                                                        **self.optimizer_args)
    else:
        raise TypeError(
            "optimizer_args must be a list or dict config, got "
            f"{type(self.optimizer_args).__name__}"
        )

    if self.scheduler is None:
        return optimizer
    scheduler = getattr(optimizers, self.scheduler)(optimizer,
                                                    **self.scheduler_args)
    return [optimizer], [scheduler]
def _recurse_on_config(self, config):
    """Recursively check the download resources declared in ``config``.

    A non-empty list whose first element has a ``url`` key is a resource
    list. Each entry becomes a ``DownloadableFile`` whose URL is checked once
    with ``check_header``. Dict configs must contain ``version`` and
    ``resources`` together, and their children are visited recursively.
    """
    is_resource_list = (
        OmegaConf.is_list(config) and len(config) > 0 and "url" in config[0]
    )
    if is_resource_list:
        for resource in config:
            # Constructing DownloadableFile validates the entry's fields.
            downloadable = DownloadableFile(**resource)
            check_header(downloadable._url, from_google=downloadable._from_google)
        return

    if not OmegaConf.is_dict(config):
        return

    # "version" and "resources" must appear together.
    if "version" in config:
        self.assertIn("resources", config)
    if "resources" in config:
        self.assertIn("version", config)

    for child_key in config:
        self._recurse_on_config(config[child_key])
def _recurse_on_config(self, config: DictConfig, callback: typing.Callable):
    """Walk a dataset config and pass every downloadable resource to ``callback``.

    A non-empty list whose first element has a ``url`` key is a resource
    list, and each entry becomes a ``DownloadableFile`` handed to
    ``callback``. Dict configs must contain ``version`` and ``resources``
    together, and their children are visited recursively with the same
    callback.
    """
    if OmegaConf.is_list(config) and len(config) > 0 and "url" in config[0]:
        for resource in config:
            # Building DownloadableFile first validates the entry itself.
            downloadable = DownloadableFile(**resource)
            # The callback performs the scenario-specific checks.
            callback(downloadable)
        return

    if OmegaConf.is_dict(config):
        paired_keys = (("version", "resources"), ("resources", "version"))
        for present, required in paired_keys:
            if present in config:
                self.assertIn(required, config)
        for child in config:
            self._recurse_on_config(config[child], callback=callback)
def test_is_list(cfg: Any, expected: bool) -> None:
    """``OmegaConf.is_list`` must report ``expected`` for ``cfg``."""
    actual = OmegaConf.is_list(cfg)
    assert actual == expected
def __init__(self, config, mode, data_path, label_file):
    """Set up one dataset split from its config.

    Args:
        config: dataset config. Keys read include ``target_type``,
            ``pred_time``, ``window_size``, ``file_type``,
            ``<mode>.slide_step`` and, depending on ``file_type``, either
            ``<mode>.file`` or ``file_prefix`` / ``file_extension``.
            ``input_cols`` is read for the ``'predict'`` target type, and
            ``feature_cols`` is read when present.
        mode: one of ``'train'``, ``'val'``, ``'test'``, ``'predict'``.
        data_path: path object for the data location; must exist.
        label_file: passed to ``self.load_label`` when not None; ``None``
            means the split has no labels.

    Raises:
        FileNotFoundError: if ``data_path`` does not exist.
    """
    assert mode in ['train', 'val', 'test', 'predict']
    self.config = config
    self.target_type = config.target_type
    assert self.target_type in ['generative', 'predict']
    if self.target_type in ['predict']:
        # NOTE(review): presumably input_cols is expected to be a list — confirm.
        self.target_cols = config.input_cols
    self.pred_time = config.pred_time
    self.window_size = config.window_size
    # Per-split section; getattr replaces eval(f'config.{mode}...').
    split_config = getattr(config, mode)
    self.slide_step = split_config.slide_step

    # Default so the attribute always exists.
    self.feature_cols = []
    # Misspelled legacy attribute (always an empty list), kept for compatibility.
    self.featrue_cols = []
    if 'feature_cols' in config:
        if OmegaConf.is_list(config.feature_cols):
            self.feature_cols = OmegaConf.to_container(config.feature_cols)

    self.data_file_type = config.file_type
    if not data_path.exists():
        raise FileNotFoundError(f'{data_path} does not found.')

    def _locate(name):
        # Look under data_path first, falling back to its parent directory.
        candidate = data_path / name
        return candidate if candidate.exists() else data_path.parent / name

    # The train/val (etc.) files are specified directly in the yaml config.
    if self.data_file_type == 'file':
        datafiles = split_config.file
        if OmegaConf.is_list(datafiles):
            datafiles = OmegaConf.to_container(datafiles)
        if isinstance(datafiles, list):
            self.data_file_paths = [_locate(name).resolve() for name in datafiles]
        else:
            self.data_file_paths = [_locate(datafiles).resolve()]
    elif self.data_file_type == 'dir':
        prefix = config.file_prefix if 'file_prefix' in config else ''
        extension = config.file_extension if 'file_extension' in config else 'csv'
        self.data_file_paths = [
            file.resolve() for file in data_path.glob(prefix + '*.' + extension)
        ]

    # Label file handling.
    if label_file is not None:
        self.have_label = True
        self.labels = self.load_label(label_file)
    else:
        self.have_label = False
def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
    """
    :param config: A config object describing what to call and what params to use.
                   In addition to the parameters, the config must contain:
                   _target_ : target class or callable name (str)
                   And may contain:
                   _args_: List-like of positional arguments to pass to the target
                   _recursive_: Construct nested objects as well (bool).
                                True by default.
                                may be overridden via a _recursive_ key in
                                the kwargs
                   _convert_: Conversion strategy
                        none    : Passed objects are DictConfig and ListConfig, default
                        partial : Passed objects are converted to dict and list, with
                                  the exception of Structured Configs (and their fields).
                        all     : Passed objects are dicts, lists and primitives without
                                  a trace of OmegaConf containers
                   _partial_: If True, return functools.partial wrapped method or object
                              False by default. Configure per target.
    :param args: Optional positional parameters pass-through
    :param kwargs: Optional named parameters to override
                   parameters in the config object. Parameters not present
                   in the config objects are being passed as is to the target.
                   IMPORTANT: dataclasses instances in kwargs are interpreted as config
                              and cannot be used as passthrough
                   For a top-level list config only the _recursive_, _convert_
                   and _partial_ keys of kwargs are used; other kwargs are ignored.
    :return: if _target_ is a class name: the instantiated object
             if _target_ is a callable: the return value of the call
    :raises InstantiationException: for a TargetConf whose _target_ is missing,
             for _partial_ with a top-level list, or for unsupported config types.
    """

    # Return None if config is None
    if config is None:
        return None

    # TargetConf edge case
    if isinstance(config, TargetConf) and config._target_ == "???":
        # Specific check to give a good warning about failure to annotate _target_ as a string.
        raise InstantiationException(
            dedent(
                f"""\
                Config has missing value for key `_target_`, cannot instantiate.
                Config type: {type(config).__name__}
                Check that the `_target_` key in your dataclass is properly annotated and overridden.
                A common problem is forgetting to annotate _target_ as a string : '_target_: str = ...'"""
            )
        )
        # TODO: print full key

    if isinstance(config, (dict, list)):
        config = _prepare_input_dict_or_list(config)

    kwargs = _prepare_input_dict_or_list(kwargs)

    # Structured Config always converted first to OmegaConf
    if is_structured_config(config) or isinstance(config, (dict, list)):
        config = OmegaConf.structured(config, flags={"allow_objects": True})

    if OmegaConf.is_dict(config):
        # Finalize config (convert targets to strings, merge with kwargs)
        # Work on a deep copy so the caller's config is left untouched; the copy
        # keeps the original parent so interpolations can still resolve.
        config_copy = copy.deepcopy(config)
        config_copy._set_flag(
            flags=["allow_objects", "struct", "readonly"], values=[True, False, False]
        )
        config_copy._set_parent(config._get_parent())
        config = config_copy

        if kwargs:
            config = OmegaConf.merge(config, kwargs)

        OmegaConf.resolve(config)

        # Control keys are removed from the config so they are not passed on
        # as arguments to the target.
        _recursive_ = config.pop(_Keys.RECURSIVE, True)
        _convert_ = config.pop(_Keys.CONVERT, ConvertMode.NONE)
        _partial_ = config.pop(_Keys.PARTIAL, False)

        return instantiate_node(
            config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_
        )
    elif OmegaConf.is_list(config):
        # Finalize config (convert targets to strings, merge with kwargs)
        config_copy = copy.deepcopy(config)
        config_copy._set_flag(
            flags=["allow_objects", "struct", "readonly"], values=[True, False, False]
        )
        config_copy._set_parent(config._get_parent())
        config = config_copy

        OmegaConf.resolve(config)

        # For lists the control flags come only from kwargs; kwargs are not
        # merged into the list.
        _recursive_ = kwargs.pop(_Keys.RECURSIVE, True)
        _convert_ = kwargs.pop(_Keys.CONVERT, ConvertMode.NONE)
        _partial_ = kwargs.pop(_Keys.PARTIAL, False)

        if _partial_:
            raise InstantiationException(
                "The _partial_ keyword is not compatible with top-level list instantiation"
            )

        return instantiate_node(
            config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_
        )
    else:
        raise InstantiationException(
            dedent(
                f"""\
                Cannot instantiate config of type {type(config).__name__}.
                Top level config must be an OmegaConf DictConfig/ListConfig object,
                a plain dict/list, or a Structured Config class or instance."""
            )
        )
def test_is_config(cfg: Any, is_conf: bool, is_list: bool, is_dict: bool) -> None:
    """Check OmegaConf's container predicates on ``cfg`` against expectations.

    The checks run in order (is_config, is_list, is_dict) and stop at the
    first mismatch.
    """
    predicates = (OmegaConf.is_config, OmegaConf.is_list, OmegaConf.is_dict)
    expectations = (is_conf, is_list, is_dict)
    for predicate, expected in zip(predicates, expectations):
        assert predicate(cfg) == expected
def test_is_config(cfg: Any, is_conf: bool, is_list: bool, is_dict: bool,
                   type_: Type[Any]) -> None:
    """Check OmegaConf's container predicates and ``get_type`` for ``cfg``.

    The checks run in order (is_config, is_list, is_dict, get_type) and stop
    at the first mismatch.
    """
    checks = (
        (OmegaConf.is_config, is_conf),
        (OmegaConf.is_list, is_list),
        (OmegaConf.is_dict, is_dict),
        (OmegaConf.get_type, type_),
    )
    for check, expected in checks:
        assert check(cfg) == expected
def _get_kwargs(
    config: Union[DictConfig, ListConfig],
    root: bool = True,
    **kwargs: Any,
) -> Any:
    """Turn ``config`` (merged with ``kwargs``) into call arguments.

    For a ListConfig, returns a plain Python list in which nested configs are
    processed recursively (with ``root=False``) and other elements are kept.

    For a DictConfig, ``kwargs`` are merged INTO ``config`` in place (the
    caller's config is mutated), and a new DictConfig is returned. Entries
    are read with ``items_ex(resolve=False)``, so interpolations stay
    unresolved. When ``_is_recursive(config, kwargs)`` is true:
      * none-valued nodes are copied as-is;
      * values for which ``_is_target`` holds are instantiated via
        ``hydra.utils.instantiate``;
      * dict and list values are rebuilt (targets instantiated, nested
        configs recursed into), keeping their ``object_type`` metadata.

    Args:
        config: the config to convert.
        root: False for nested calls; then the result also inherits
            ``config``'s ``object_type`` metadata.
        **kwargs: overrides merged into ``config``.
    """
    from hydra.utils import instantiate

    assert OmegaConf.is_config(config)

    if OmegaConf.is_list(config):
        assert isinstance(config, ListConfig)
        return [
            _get_kwargs(x, root=False) if OmegaConf.is_config(x) else x for x in config
        ]

    assert OmegaConf.is_dict(config), "Input config is not an OmegaConf DictConfig"

    recursive = _is_recursive(config, kwargs)
    overrides = OmegaConf.create(kwargs, flags={"allow_objects": True})
    # NOTE: mutates the caller's config.
    config.merge_with(overrides)

    final_kwargs = OmegaConf.create(flags={"allow_objects": True})
    # Same parent as config so relative interpolations can still resolve.
    final_kwargs._set_parent(config._get_parent())
    # Temporarily writable/non-struct while it is being filled in.
    final_kwargs._set_flag("readonly", False)
    final_kwargs._set_flag("struct", False)
    if recursive:
        for k, v in config.items_ex(resolve=False):
            if OmegaConf.is_none(v):
                final_kwargs[k] = v
            elif _is_target(v):
                final_kwargs[k] = instantiate(v)
            elif OmegaConf.is_dict(v):
                d = OmegaConf.create({}, flags={"allow_objects": True})
                for key, value in v.items_ex(resolve=False):
                    if _is_target(value):
                        d[key] = instantiate(value)
                    elif OmegaConf.is_config(value):
                        d[key] = _get_kwargs(value, root=False)
                    else:
                        d[key] = value
                d._metadata.object_type = v._metadata.object_type
                final_kwargs[k] = d
            elif OmegaConf.is_list(v):
                lst = OmegaConf.create([], flags={"allow_objects": True})
                for x in v:
                    if _is_target(x):
                        lst.append(instantiate(x))
                    elif OmegaConf.is_config(x):
                        lst.append(_get_kwargs(x, root=False))
                        lst[-1]._metadata.object_type = x._metadata.object_type
                    else:
                        lst.append(x)
                final_kwargs[k] = lst
            else:
                final_kwargs[k] = v
    else:
        for k, v in config.items_ex(resolve=False):
            final_kwargs[k] = v

    # Reset the flags to None (i.e. unset) now that construction is done.
    final_kwargs._set_flag("readonly", None)
    final_kwargs._set_flag("struct", None)
    final_kwargs._set_flag("allow_objects", None)
    if not root:
        # This is tricky: since the root kwargs are exploded anyway, we can
        # treat them as an untyped dict. The motivation is that the object
        # type is used as an indicator to treat the object differently during
        # conversion to a primitive container in some cases.
        final_kwargs._metadata.object_type = config._metadata.object_type
    return final_kwargs
def _wrap_with_list(x):
    """Return ``x`` unchanged if it is an OmegaConf list config, else ``[x]``.

    Note that a plain Python list is not a list config, so it is wrapped too.
    """
    return x if OmegaConf.is_list(x) else [x]
def instantiate_node(
    node: Any,
    *args: Any,
    convert: Union[str, ConvertMode] = ConvertMode.NONE,
    recursive: bool = True,
) -> Any:
    """Instantiate ``node`` (and, depending on flags, its children).

    * ``None`` / OmegaConf "none" nodes return ``None``; non-config values
      are returned unchanged.
    * A dict node's own ``_convert_`` / ``_recursive_`` keys override the
      inherited ``convert`` / ``recursive`` arguments.
    * List nodes: every item is instantiated. The result is a plain list for
      ``ConvertMode.ALL`` / ``PARTIAL``, otherwise a ListConfig parented to
      ``node``.
    * Dict nodes with a target: the target is resolved, and the remaining
      keys (except ``_target_``, ``_convert_``, ``_recursive_``) are
      instantiated when ``recursive`` is set, converted with
      ``_convert_node``, and passed as keyword arguments to ``_call_target``
      together with ``*args``. ``*args`` is only forwarded at this level, not
      to nested nodes.
    * Dict nodes without a target: a plain dict (eager resolution) for
      ``ALL``, or for ``PARTIAL`` on non-structured nodes; otherwise a
      DictConfig (lazy resolution) that keeps ``node``'s object_type.

    Raises:
        TypeError: if the effective ``recursive`` flag is not a bool.
    """
    # Return None if config is None
    if node is None or OmegaConf.is_none(node):
        return None

    if not OmegaConf.is_config(node):
        return node

    # Override parent modes from config if specified
    if OmegaConf.is_dict(node):
        # using getitem instead of get(key, default) because OmegaConf will raise an exception
        # if the key type is incompatible on get.
        convert = node[_Keys.CONVERT] if _Keys.CONVERT in node else convert
        recursive = node[_Keys.RECURSIVE] if _Keys.RECURSIVE in node else recursive

    if not isinstance(recursive, bool):
        raise TypeError(f"_recursive_ flag must be a bool, got {type(recursive)}")

    # If OmegaConf list, create new list of instances if recursive
    # NOTE(review): this branch instantiates items regardless of `recursive`;
    # the flag is only passed down to the children — confirm this is intended.
    if OmegaConf.is_list(node):
        items = [
            instantiate_node(item, convert=convert, recursive=recursive)
            for item in node._iter_ex(resolve=True)
        ]

        if convert in (ConvertMode.ALL, ConvertMode.PARTIAL):
            # If ALL or PARTIAL, use plain list as container
            return items
        else:
            # Otherwise, use ListConfig as container
            lst = OmegaConf.create(items, flags={"allow_objects": True})
            lst._set_parent(node)
            return lst

    elif OmegaConf.is_dict(node):
        exclude_keys = set({"_target_", "_convert_", "_recursive_"})
        if _is_target(node):
            target = _resolve_target(node.get(_Keys.TARGET))
            kwargs = {}
            for key, value in node.items_ex(resolve=True):
                if key not in exclude_keys:
                    if recursive:
                        value = instantiate_node(
                            value, convert=convert, recursive=recursive
                        )
                    kwargs[key] = _convert_node(value, convert)

            return _call_target(target, *args, **kwargs)
        else:
            # If ALL or PARTIAL non structured, instantiate in dict and resolve interpolations eagerly.
            if convert == ConvertMode.ALL or (
                convert == ConvertMode.PARTIAL and node._metadata.object_type is None
            ):
                dict_items = {}
                for key, value in node.items_ex(resolve=True):
                    # list items inherits recursive flag from the containing dict.
                    dict_items[key] = instantiate_node(
                        value, convert=convert, recursive=recursive
                    )
                return dict_items
            else:
                # Otherwise use DictConfig and resolve interpolations lazily.
                cfg = OmegaConf.create({}, flags={"allow_objects": True})
                for key, value in node.items_ex(resolve=False):
                    cfg[key] = instantiate_node(
                        value, convert=convert, recursive=recursive
                    )
                cfg._set_parent(node)
                cfg._metadata.object_type = node._metadata.object_type
                return cfg

    else:
        assert False, f"Unexpected config type : {type(node).__name__}"
def splitNameConf(conf, search, default_name: str = None):
    """Split a "named" config into ``(attribute of search, remaining config)``.

    Accepted forms of ``conf``:

    * ``str``: a bare name with no extra config; returns ``{}`` as the config.
    * DictConfig: the ``name`` key (or ``default_name`` if absent) selects the
      attribute. ``name`` is popped from ``conf`` in place, and the rest of
      ``conf`` is returned.
    * ListConfig: ``[name]`` or ``[name, cfg]``; returns ``cfg``, or ``{}``
      when only the name is given.

    Raises:
        TypeError: if ``conf`` is none of the forms above. Previously ``None``
            was returned, which failed later with an opaque unpacking error.
    """
    if isinstance(conf, str):
        return getattr(search, conf), {}
    if OmegaConf.is_dict(conf):
        return getattr(search, conf.pop("name", default_name)), conf
    if OmegaConf.is_list(conf):
        return getattr(search, conf[0]), ({} if len(conf) == 1 else conf[1])
    raise TypeError(
        "Expected a name string, DictConfig or ListConfig, got "
        f"{type(conf).__name__}"
    )