Ejemplo n.º 1
0
def parse_flags(flag_list, arg_parser: argparse.ArgumentParser,
                args_preload_func=_args_preload_from_config_files):
    """Resolves a value for every flag in `flag_list`.

    For each flag the value is taken from the command line when it is not None,
    otherwise from the pre-loaded arguments (when the key is present there),
    otherwise from the flag's `default`. A `ModuleFlag` is stored under its
    `cls_key`. Its parameter dict is stored under its `params_key`: the command-line
    dict (or `{}`) is combined with any pre-loaded dict via
    `deep_merge_dict(preloaded, command_line)`.

    Args:
        flag_list: A list of flags.
        arg_parser: The program argument parser.
        args_preload_func: A callable receiving the parsed argparse namespace and
            returning a mapping of pre-loaded arguments (e.g. from config files or a
            hyper-parameter set). Pass None to skip pre-loading.

    Returns:
        A tuple `(resolved_args, remaining_argv)`. `resolved_args` is a dict keyed
        by flag name / `cls_key` / `params_key`. `remaining_argv` holds the
        command-line tokens `arg_parser` did not recognize.
    """
    namespace, remaining_argv = arg_parser.parse_known_args()
    preloaded = {} if args_preload_func is None else args_preload_func(namespace)
    cli_args = yaml_load_checking(namespace.__dict__)

    resolved = {}
    for flag in flag_list:
        key = flag.name
        if isinstance(flag, ModuleFlag):
            key = flag.cls_key
            cli_params = cli_args.get(flag.params_key, None)
            module_params = {} if cli_params is None else cli_params
            if flag.params_key in preloaded:
                module_params = deep_merge_dict(preloaded[flag.params_key], module_params)
            resolved[flag.params_key] = module_params

        cli_value = cli_args.get(key, None)
        if cli_value is not None:
            resolved[key] = cli_value
        elif key in preloaded:
            resolved[key] = preloaded[key]
        else:
            resolved[key] = flag.default

    return resolved, remaining_argv
Ejemplo n.º 2
0
def _pre_load_args(args):
    """Pre-loads arguments from config files, a hyper-parameter set and a saved model config.

    Sources:
      * the config file(s) named by `flags_core.DEFAULT_CONFIG_FLAG`;
      * the hyper-parameter set `args.hparams_set`, falling back to the config
        file's "hparams_set". Its "model.class", "model" and "model.params"
        entries are kept at the top level; any other entries are nested under
        "entry.params";
      * the model config of the first model directory (`args.model_dir`, else the
        config file's "model_dir"), when one is given and can be loaded.

    Args:
        args: The parsed program arguments (argparse namespace).

    Returns:
        The sources combined with `deep_merge_dict`, in this order: model config
        (if loaded), then hyper-parameter set, then config-file arguments.
        Presumably later sources override earlier ones; that depends on
        `deep_merge_dict`.
    """
    cfg_file_args = yaml_load_checking(
        load_from_config_path(
            flatten_string_list(
                getattr(args, flags_core.DEFAULT_CONFIG_FLAG.name))))
    model_dirs = flatten_string_list(args.model_dir
                                     or cfg_file_args.get("model_dir", None))
    hparams_set = args.hparams_set
    if hparams_set is None:
        hparams_set = cfg_file_args.get("hparams_set", None)
    predefined_parameters = get_hyper_parameters(hparams_set)

    formatted_parameters = {}
    # Model-level keys stay at the top level; everything else in the
    # hyper-parameter set belongs to the entry's params.
    for key in ("model.class", "model", "model.params"):
        if key in predefined_parameters:
            formatted_parameters[key] = predefined_parameters.pop(key)
    if predefined_parameters:
        formatted_parameters["entry.params"] = predefined_parameters

    # Only the optional load is best-effort; merge errors must not be swallowed.
    model_cfgs = None
    if model_dirs:
        try:
            model_cfgs = ModelConfigs.load(model_dirs[0])
        except Exception:
            # The directory may not (yet) contain a loadable model config.
            model_cfgs = None
    if model_cfgs is None:
        return deep_merge_dict(formatted_parameters, cfg_file_args)
    return deep_merge_dict(
        deep_merge_dict(model_cfgs, formatted_parameters), cfg_file_args)
Ejemplo n.º 3
0
def extend_define_and_parse(flag_name, args, remaining_argv, backend="tf"):
    """Parses the flags of the sub-modules nested inside an already-selected module.

    `flag_name` names a defined `ModuleFlag` (looked up in `_DEFINED_FLAGS`) whose
    class is already chosen in `args[f.cls_key]`. For every `ModuleFlag` argument of
    that class whose own class is set in `args[f.params_key]`, the sub-module's flags
    are defined on a fresh parser and parsed out of `remaining_argv`. Each value is
    then stored in `args[f.params_key][<sub-module params_key>]`, and the matching
    top-level entries are popped from `args`.

    Args:
        flag_name: Name of a defined module flag.
        args: The dict of parsed arguments; modified in place.
        remaining_argv: Command-line tokens not consumed by earlier parsing.
        backend: The DL backend whose registries are searched.

    Returns:
        A tuple `(args, remaining_argv)`. When `flag_name` is not a defined
        `ModuleFlag`, or its selected class has no `class_or_method_args`, `args`
        and `remaining_argv` are returned unchanged.
    """
    f = _DEFINED_FLAGS.get(flag_name, None)
    # Early exits return the same (args, remaining_argv) pair as the normal
    # path so callers can always unpack the result.
    if f is None or not isinstance(f, ModuleFlag):
        return args, remaining_argv
    module_cls = REGISTRIES[backend][f.module_name][args[f.cls_key]]
    if not hasattr(module_cls, "class_or_method_args"):
        return args, remaining_argv

    # Resolve, once, the selected class of every nested module flag that has a
    # class chosen in args[f.params_key] and declares flags of its own.
    submodules = []
    for ff in module_cls.class_or_method_args():
        if not isinstance(ff, ModuleFlag):
            continue
        if not args[f.params_key].get(ff.cls_key, None):
            continue
        this_cls = REGISTRIES[backend][ff.module_name][args[f.params_key][ff.cls_key]]
        if hasattr(this_cls, "class_or_method_args"):
            submodules.append((ff, this_cls))

    # Define the flags of those sub-modules and parse them from the leftover argv.
    arg_parser = argparse.ArgumentParser()
    for _, this_cls in submodules:
        for fff in this_cls.class_or_method_args():
            fff.define(arg_parser)
    parsed_args, remaining_argv = arg_parser.parse_known_args(remaining_argv)
    parsed_args = yaml_load_checking(parsed_args.__dict__)

    for ff, this_cls in submodules:
        module_params = args[f.params_key]
        if module_params.get(ff.params_key, None) is None:
            module_params[ff.params_key] = {}
        sub_params = module_params[ff.params_key]
        for fff in this_cls.class_or_method_args():
            flag_key = fff.cls_key if isinstance(fff, ModuleFlag) else fff.name
            # Precedence: value parsed just now > top-level args under flag_key
            # > top-level args under the flag's plain name > a value already
            # nested under the plain name (renamed to flag_key).
            if parsed_args[flag_key] is not None:
                sub_params[flag_key] = parsed_args[flag_key]
                args.pop(flag_key, None)
                args.pop(fff.name, None)
            elif flag_key in args:
                sub_params[flag_key] = args.pop(flag_key)
                args.pop(fff.name, None)
            elif fff.name in args:
                sub_params[flag_key] = args.pop(fff.name)
            elif fff.name in sub_params and flag_key not in sub_params:
                # Rename only when flag_key is unset, so an existing value is kept.
                sub_params[flag_key] = sub_params.pop(fff.name)
            if isinstance(fff, ModuleFlag):
                sub_params[fff.params_key] = deep_merge_dict(
                    sub_params.get(fff.params_key, {}) or {},
                    deep_merge_dict(args.get(fff.params_key, {}) or {},
                                    parsed_args.get(fff.params_key, {}) or {}))

    return args, remaining_argv
Ejemplo n.º 4
0
def intelligent_parse_flags(flag_list, arg_parser: argparse.ArgumentParser,
                            args_preload_func=_args_preload_from_config_files,
                            backend="tf"):
    """ Parses flags from argument parser, including flags of selected sub-modules.

    Plain flags take the command-line value, else the (flattened) pre-loaded value,
    else `f.default`. For module flags, the class comes from the command line, else
    from the pre-loaded arguments. A module-specific parser (`get_argparser`) then
    parses the class's own flags from the remaining argv, and those values are
    collected under `f.params_key`.

    Args:
        flag_list: A list of flags.
        arg_parser: The program argument parser.
        args_preload_func: A callable function for pre-loading arguments, maybe from
            config file, hyper parameter set. It receives the parsed argparse
            namespace. Pass None to skip pre-loading.
        backend: The DL backend (key into `REGISTRIES`).

    Returns:
        A tuple `(top_program_parsed_args, remaining_argv)`. `remaining_argv` holds
        the tokens left unparsed by the program parser and all module parsers.
    """
    program_parsed_args, remaining_argv = arg_parser.parse_known_args()
    cfg_file_args = {}
    if args_preload_func is not None:
        cfg_file_args = args_preload_func(program_parsed_args)
    top_program_parsed_args = _flatten_args(flag_list,
                                            yaml_load_checking(program_parsed_args.__dict__))
    # A module class given on the command line replaces the pre-loaded one
    # before the pre-loaded args are flattened. Note this mutates the dict
    # returned by args_preload_func.
    for f in flag_list:
        if isinstance(f, ModuleFlag):
            if f.cls_key in top_program_parsed_args and top_program_parsed_args[f.cls_key]:
                cfg_file_args[f.cls_key] = top_program_parsed_args[f.cls_key]
    cfg_file_args = _flatten_args(flag_list, cfg_file_args)
    for f in flag_list:
        # NOTE(review): if ModuleFlag subclasses Flag, module flags also take this
        # branch and never reach the `else` below -- confirm the class hierarchy.
        if isinstance(f, Flag):
            if top_program_parsed_args[f.name] is None:
                top_program_parsed_args[f.name] = cfg_file_args.get(f.name, None)
            # Consumed keys are removed so the final merge below does not re-add them.
            cfg_file_args.pop(f.name, None)
        else:
            # Treated as a module flag: uses cls_key / params_key / module_name.
            submodule_cls = (top_program_parsed_args.get(f.cls_key, None)
                             or cfg_file_args.get(f.cls_key, None))
            cfg_file_args.pop(f.cls_key, None)
            if submodule_cls is None:
                continue
            top_program_parsed_args[f.cls_key] = submodule_cls
            if top_program_parsed_args.get(f.params_key, None) is None:
                top_program_parsed_args[f.params_key] = {}
            # Parse the selected sub-module's own flags from the leftover argv.
            module_arg_parser = get_argparser(f.module_name, submodule_cls)
            module_parsed_args, remaining_argv = module_arg_parser.parse_known_args(remaining_argv)
            module_parsed_args = yaml_load_checking(module_parsed_args.__dict__)

            if hasattr(REGISTRIES[backend][f.module_name][submodule_cls], "class_or_method_args"):
                # Pre-loaded args flattened against the sub-module's own flag list.
                key_cfg_file_args = _flatten_args(
                    REGISTRIES[backend][f.module_name][submodule_cls].class_or_method_args(), cfg_file_args)
                for inner_f in REGISTRIES[backend][f.module_name][submodule_cls].class_or_method_args():
                    flag_key = inner_f.name
                    if isinstance(inner_f, ModuleFlag):
                        flag_key = inner_f.cls_key
                        cfg_file_args.pop(flag_key, None)
                    # Precedence for each inner flag: module command line >
                    # top-level parsed args > pre-loaded args. The chosen value moves
                    # under f.params_key, and the key is dropped from the other dicts.
                    # Assumes module_arg_parser defines every flag_key (else KeyError).
                    if module_parsed_args[flag_key] is not None:
                        top_program_parsed_args[f.params_key][flag_key] = module_parsed_args[flag_key]
                        top_program_parsed_args.pop(flag_key, None)
                        key_cfg_file_args.pop(flag_key, None)
                        cfg_file_args.pop(flag_key, None)
                    elif flag_key in top_program_parsed_args:
                        top_program_parsed_args[f.params_key][flag_key] = top_program_parsed_args.pop(flag_key)
                        key_cfg_file_args.pop(flag_key, None)
                        cfg_file_args.pop(flag_key, None)
                    elif flag_key in key_cfg_file_args:
                        top_program_parsed_args[f.params_key][flag_key] = key_cfg_file_args.pop(flag_key)
                        cfg_file_args.pop(flag_key, None)

                    # Nested module params: combine pre-loaded, already-nested,
                    # top-level and module-parsed dicts (each popped from its
                    # source). Presumably later arguments of deep_merge_dict take
                    # precedence -- confirm.
                    if isinstance(inner_f, ModuleFlag):
                        top_program_parsed_args[f.params_key][inner_f.params_key] = deep_merge_dict(
                            cfg_file_args.pop(inner_f.params_key, {}) or {},
                            deep_merge_dict(top_program_parsed_args[f.params_key].pop(inner_f.params_key, {}) or {},
                                            deep_merge_dict(top_program_parsed_args.pop(inner_f.params_key, {}) or {},
                                                            module_parsed_args.pop(inner_f.params_key, {}) or {})))
    # Carry over any pre-loaded args not consumed above.
    top_program_parsed_args = deep_merge_dict(cfg_file_args, top_program_parsed_args)
    # Fall back to declared defaults for plain flags that are still unset.
    for f in flag_list:
        if isinstance(f, Flag):
            if f.name not in top_program_parsed_args or top_program_parsed_args[f.name] is None:
                top_program_parsed_args[f.name] = f.default
    return top_program_parsed_args, remaining_argv
Ejemplo n.º 5
0
def _args_preload_from_config_files(args):
    """Loads pre-set arguments from the config file(s) named on the command line.

    The value of the default config flag is read from `args` with `getattr`, which
    yields None when the attribute is missing. That value is normalized with
    `flatten_string_list`, loaded with `load_from_config_path` and post-processed
    with `yaml_load_checking`.

    Args:
        args: The parsed program arguments (argparse namespace).

    Returns:
        Whatever `yaml_load_checking` produces for the loaded configuration.
    """
    config_paths = flatten_string_list(getattr(args, DEFAULT_CONFIG_FLAG.name, None))
    loaded = load_from_config_path(config_paths)
    return yaml_load_checking(loaded)