def parse_flags(flag_list, arg_parser: argparse.ArgumentParser, args_preload_func=_args_preload_from_config_files):
    """Parse the top-level program flags.

    For every flag the value is resolved in this order of preference: the command-line
    value (when not None), then the entry in the dict returned by ``args_preload_func``,
    then the flag's ``default``. A ``ModuleFlag`` is resolved under its ``cls_key``; in
    addition its ``params_key`` entry is set to the command-line params (or ``{}``),
    deep-merged on top of the pre-loaded params when those exist.

    Args:
        flag_list: A list of flags.
        arg_parser: The program argument parser.
        args_preload_func: A callable function for pre-loading arguments, maybe from
            config file, hyper parameter set. ``None`` disables pre-loading.

    Returns:
        A tuple ``(parsed_args, remaining_argv)`` where ``parsed_args`` is a dict of the
        resolved flag values and ``remaining_argv`` is what ``parse_known_args`` left over.
    """
    namespace, remaining_argv = arg_parser.parse_known_args()
    preloaded = {} if args_preload_func is None else args_preload_func(namespace)
    cli_args = yaml_load_checking(namespace.__dict__)

    resolved = {}
    for flag in flag_list:
        key = flag.name
        if isinstance(flag, ModuleFlag):
            key = flag.cls_key
            params = cli_args.get(flag.params_key, None)
            if params is None:
                params = {}
            if flag.params_key in preloaded:
                params = deep_merge_dict(preloaded[flag.params_key], params)
            resolved[flag.params_key] = params

        cli_value = cli_args.get(key, None)
        if cli_value is not None:
            resolved[key] = cli_value
        elif key in preloaded:
            resolved[key] = preloaded[key]
        else:
            resolved[key] = flag.default
    return resolved, remaining_argv
def _pre_load_args(args):
    """Pre-load arguments from config files, a hyper-parameter set and a model directory.

    Three sources are combined:

    1. ``cfg_file_args``: the config file(s) named by the default config flag on ``args``.
    2. The hyper-parameter set ``args.hparams_set`` (falling back to
       ``cfg_file_args["hparams_set"]``). Its ``model.class`` / ``model`` / ``model.params``
       entries are kept at the top level. Every other entry is nested under ``entry.params``.
    3. The model configs loaded from the first model directory (``args.model_dir``, falling
       back to ``cfg_file_args["model_dir"]``), when that load succeeds.

    The sources are combined with ``deep_merge_dict`` in the order: model configs, then
    hyper-parameters, then config file. Presumably later sources override earlier ones;
    confirm against ``deep_merge_dict``.

    Args:
        args: The namespace produced by the program argument parser.

    Returns:
        The merged arguments dict.
    """
    cfg_file_args = yaml_load_checking(
        load_from_config_path(
            flatten_string_list(getattr(args, flags_core.DEFAULT_CONFIG_FLAG.name))))
    model_dirs = flatten_string_list(args.model_dir or cfg_file_args.get("model_dir", None))
    hparams_set = args.hparams_set
    if hparams_set is None:
        hparams_set = cfg_file_args.get("hparams_set", None)
    predefined_parameters = get_hyper_parameters(hparams_set)

    # Model-level keys stay at the top level; everything else from the hyper-parameter
    # set is treated as entry parameters.
    formatted_parameters = {}
    for key in ("model.class", "model", "model.params"):
        if key in predefined_parameters:
            formatted_parameters[key] = predefined_parameters.pop(key)
    if predefined_parameters:
        formatted_parameters["entry.params"] = predefined_parameters

    # Loading model configs is best-effort: any failure falls back to the arguments
    # without them. Only the load itself is guarded. A missing model dir is checked
    # explicitly, and errors raised while merging are no longer hidden by the fallback.
    if model_dirs:
        try:
            model_cfgs = ModelConfigs.load(model_dirs[0])
        except Exception:
            pass  # deliberate fallback below
        else:
            return deep_merge_dict(
                deep_merge_dict(model_cfgs, formatted_parameters), cfg_file_args)
    return deep_merge_dict(formatted_parameters, cfg_file_args)
def _selected_submodule_classes(module_cls, module_params, backend):
    """Yield ``(flag, cls)`` for each selected sub-module ``ModuleFlag`` of ``module_cls``.

    A sub-module is selected when ``module_params[flag.cls_key]`` is truthy. It is only
    yielded if the registered class defines ``class_or_method_args``.
    """
    for ff in module_cls.class_or_method_args():
        if not isinstance(ff, ModuleFlag):
            continue
        selected_name = module_params.get(ff.cls_key, None)
        if not selected_name:
            continue
        this_cls = REGISTRIES[backend][ff.module_name][selected_name]
        if hasattr(this_cls, "class_or_method_args"):
            yield ff, this_cls


def extend_define_and_parse(flag_name, args, remaining_argv, backend="tf"):
    """Parse the flags of sub-modules selected inside a module flag's params.

    Looks up ``flag_name`` in ``_DEFINED_FLAGS``. The work is only done when it is a
    ``ModuleFlag`` ``f`` whose selected class (``args[f.cls_key]``) declares
    ``class_or_method_args``. In that case, for every ``ModuleFlag`` ``ff`` among those
    args whose class is selected in ``args[f.params_key][ff.cls_key]``:

    * the sub-module's own flags are defined on a fresh parser and parsed out of
      ``remaining_argv``;
    * each sub-module flag value is stored in ``args[f.params_key][ff.params_key]``.

    The stored value is taken, in order of preference, from:

    1. the newly parsed argv;
    2. ``args`` under the flag's key;
    3. ``args`` under the flag's name;
    4. an existing entry under the flag's name, renamed to the flag's key.

    Args:
        flag_name: Name of a flag registered in ``_DEFINED_FLAGS``.
        args: The parsed arguments dict. It is modified in place.
        remaining_argv: The not-yet-parsed command-line arguments.
        backend: The DL backend whose ``REGISTRIES`` are used.

    Returns:
        A tuple ``(args, remaining_argv)``. ``args`` and ``remaining_argv`` are returned
        unchanged when ``flag_name`` is not a known ``ModuleFlag`` or its class has no
        ``class_or_method_args``.
    """
    f = _DEFINED_FLAGS.get(flag_name, None)
    if f is None or not isinstance(f, ModuleFlag):
        # Return the same (args, remaining_argv) shape as the normal path so callers
        # can always unpack two values.
        return args, remaining_argv
    module_cls = REGISTRIES[backend][f.module_name][args[f.cls_key]]
    if not hasattr(module_cls, "class_or_method_args"):
        return args, remaining_argv
    module_params = args[f.params_key]

    # Pass 1: define every selected sub-module's flags and parse them from the argv.
    arg_parser = argparse.ArgumentParser()
    for _, this_cls in _selected_submodule_classes(module_cls, module_params, backend):
        for fff in this_cls.class_or_method_args():
            fff.define(arg_parser)
    parsed_args, remaining_argv = arg_parser.parse_known_args(remaining_argv)
    parsed_args = yaml_load_checking(parsed_args.__dict__)

    # Pass 2: move each sub-module flag value into args[f.params_key][ff.params_key].
    for ff, this_cls in _selected_submodule_classes(module_cls, module_params, backend):
        sub_params = module_params.get(ff.params_key, None)
        if sub_params is None:
            sub_params = module_params[ff.params_key] = {}
        for fff in this_cls.class_or_method_args():
            flag_key = fff.cls_key if isinstance(fff, ModuleFlag) else fff.name
            if parsed_args[flag_key] is not None:
                sub_params[flag_key] = parsed_args[flag_key]
                args.pop(flag_key, None)
                args.pop(fff.name, None)
            elif flag_key in args:
                sub_params[flag_key] = args.pop(flag_key)
                args.pop(fff.name, None)
            elif fff.name in args:
                sub_params[flag_key] = args.pop(fff.name)
            elif fff.name in sub_params and flag_key not in sub_params:
                # Rename an entry stored under the flag's name to its canonical key,
                # but never overwrite a canonical key that is already set.
                sub_params[flag_key] = sub_params.pop(fff.name)
            if isinstance(fff, ModuleFlag):
                sub_params[fff.params_key] = deep_merge_dict(
                    sub_params.get(fff.params_key, {}) or {},
                    deep_merge_dict(args.get(fff.params_key, {}) or {},
                                    parsed_args.get(fff.params_key, {}) or {}))
    return args, remaining_argv
def intelligent_parse_flags(flag_list, arg_parser: argparse.ArgumentParser,
                            args_preload_func=_args_preload_from_config_files, backend="tf"):
    """ Parses flags from argument parser, expanding module flags into their sub-module flags.

    Plain flags take the command-line value, else the pre-loaded value, else ``f.default``.

    For module flags:

    * the sub-module class is chosen from the command line, else from the pre-loaded args;
    * that class's own arguments are parsed from the remaining argv;
    * the results are collected under the flag's ``params_key``.

    Any pre-loaded entries not consumed are finally deep-merged with the parsed args.

    Args:
        flag_list: A list of flags.
        arg_parser: The program argument parser.
        args_preload_func: A callable function for pre-loading arguments, maybe from
            config file, hyper parameter set. ``None`` disables pre-loading.
        backend: The DL backend whose ``REGISTRIES`` are consulted.

    Returns:
        A tuple ``(parsed_args, remaining_argv)``: the resolved args dict and the argv
        left over after the program parser and every sub-module parser have consumed theirs.
    """
    program_parsed_args, remaining_argv = arg_parser.parse_known_args()
    cfg_file_args = {}
    if args_preload_func is not None:
        cfg_file_args = args_preload_func(program_parsed_args)
    top_program_parsed_args = _flatten_args(flag_list, yaml_load_checking(program_parsed_args.__dict__))
    # A module class given on the command line replaces the one from the pre-loaded args
    # before those are flattened. Presumably this lets _flatten_args pick up the params of
    # the class actually chosen; confirm against _flatten_args.
    for f in flag_list:
        if isinstance(f, ModuleFlag):
            if f.cls_key in top_program_parsed_args and top_program_parsed_args[f.cls_key]:
                cfg_file_args[f.cls_key] = top_program_parsed_args[f.cls_key]
    cfg_file_args = _flatten_args(flag_list, cfg_file_args)
    for f in flag_list:
        # NOTE(review): the else branch treats every non-Flag as a ModuleFlag; if ModuleFlag
        # subclasses Flag it would be unreachable -- confirm the class hierarchy.
        if isinstance(f, Flag):
            # Command-line value wins; otherwise fall back to the pre-loaded value.
            if top_program_parsed_args[f.name] is None:
                top_program_parsed_args[f.name] = cfg_file_args.get(f.name, None)
            # Mark the pre-loaded entry as consumed.
            cfg_file_args.pop(f.name, None)
        else:
            submodule_cls = (top_program_parsed_args.get(f.cls_key, None)
                             or cfg_file_args.get(f.cls_key, None))
            cfg_file_args.pop(f.cls_key, None)
            if submodule_cls is None:
                continue
            top_program_parsed_args[f.cls_key] = submodule_cls
            if top_program_parsed_args.get(f.params_key, None) is None:
                top_program_parsed_args[f.params_key] = {}
            # Parse the chosen sub-module's own arguments out of the still-unparsed argv.
            module_arg_parser = get_argparser(f.module_name, submodule_cls)
            module_parsed_args, remaining_argv = module_arg_parser.parse_known_args(remaining_argv)
            module_parsed_args = yaml_load_checking(module_parsed_args.__dict__)
            if hasattr(REGISTRIES[backend][f.module_name][submodule_cls], "class_or_method_args"):
                # Pre-loaded values flattened against the sub-module's own flags.
                key_cfg_file_args = _flatten_args(
                    REGISTRIES[backend][f.module_name][submodule_cls].class_or_method_args(), cfg_file_args)
                for inner_f in REGISTRIES[backend][f.module_name][submodule_cls].class_or_method_args():
                    flag_key = inner_f.name
                    if isinstance(inner_f, ModuleFlag):
                        flag_key = inner_f.cls_key
                    # NOTE(review): flag_key is popped here, so the later
                    # cfg_file_args.pop(flag_key, None) calls in this iteration are no-ops.
                    cfg_file_args.pop(flag_key, None)
                    # Preference: sub-module argv > top-level parsed args > pre-loaded args.
                    # The chosen value moves into top_program_parsed_args[f.params_key].
                    if module_parsed_args[flag_key] is not None:
                        top_program_parsed_args[f.params_key][flag_key] = module_parsed_args[flag_key]
                        top_program_parsed_args.pop(flag_key, None)
                        key_cfg_file_args.pop(flag_key, None)
                        cfg_file_args.pop(flag_key, None)
                    elif flag_key in top_program_parsed_args:
                        top_program_parsed_args[f.params_key][flag_key] = top_program_parsed_args.pop(flag_key)
                        key_cfg_file_args.pop(flag_key, None)
                        cfg_file_args.pop(flag_key, None)
                    elif flag_key in key_cfg_file_args:
                        top_program_parsed_args[f.params_key][flag_key] = key_cfg_file_args.pop(flag_key)
                        cfg_file_args.pop(flag_key, None)
                    # A nested module flag's params are deep-merged from four sources (each
                    # popped so it is consumed): pre-loaded, already-collected, top-level
                    # parsed, and sub-module parsed. Presumably later sources take precedence;
                    # confirm against deep_merge_dict.
                    if isinstance(inner_f, ModuleFlag):
                        top_program_parsed_args[f.params_key][inner_f.params_key] = deep_merge_dict(
                            cfg_file_args.pop(inner_f.params_key, {}) or {},
                            deep_merge_dict(top_program_parsed_args[f.params_key].pop(inner_f.params_key, {}) or {},
                                            deep_merge_dict(top_program_parsed_args.pop(inner_f.params_key, {}) or {},
                                                            module_parsed_args.pop(inner_f.params_key, {}) or {})))
    # Keep any pre-loaded entries that no flag consumed.
    top_program_parsed_args = deep_merge_dict(cfg_file_args, top_program_parsed_args)
    # Finally, fill plain flags that are still missing or None with their defaults.
    for f in flag_list:
        if isinstance(f, Flag):
            if f.name not in top_program_parsed_args or top_program_parsed_args[f.name] is None:
                top_program_parsed_args[f.name] = f.default
    return top_program_parsed_args, remaining_argv
def _args_preload_from_config_files(args):
    """Load pre-set arguments from the config file(s) referenced on ``args``.

    The paths are read from the attribute named by ``DEFAULT_CONFIG_FLAG.name``. A missing
    attribute is treated as ``None``. The paths are flattened, loaded with
    ``load_from_config_path``, and the result of ``yaml_load_checking`` is returned.
    """
    config_paths = flatten_string_list(getattr(args, DEFAULT_CONFIG_FLAG.name, None))
    raw_config = load_from_config_path(config_paths)
    return yaml_load_checking(raw_config)