コード例 #1
0
def _pre_load_args(args):
    """Assemble preliminary arguments from config files and hparams sets.

    Three sources are combined with ``deep_merge_dict``:

    * the configs stored in the first model directory, when they can be
      loaded via ``ModelConfigs.load``;
    * the predefined hyper-parameters selected by ``hparams_set``;
    * the contents of the config file(s) named by the default config flag.

    Args:
        args: Parsed-arguments object. Must expose ``model_dir``,
            ``hparams_set`` and the attribute named by
            ``flags_core.DEFAULT_CONFIG_FLAG.name``.

    Returns:
        The result of ``deep_merge_dict`` over the available sources.
        Presumably later arguments take precedence over earlier ones --
        TODO confirm against ``deep_merge_dict``.
    """
    config_flag_name = flags_core.DEFAULT_CONFIG_FLAG.name
    config_paths = flatten_string_list(getattr(args, config_flag_name))
    cfg_file_args = yaml_load_checking(load_from_config_path(config_paths))

    # A model_dir given on `args` wins over one found in the config file.
    model_dirs = flatten_string_list(
        args.model_dir or cfg_file_args.get("model_dir", None))

    # Only an explicit None on `args` defers to the config file.
    hparams_set = args.hparams_set
    if hparams_set is None:
        hparams_set = cfg_file_args.get("hparams_set", None)
    hparams = get_hyper_parameters(hparams_set)

    # Model-related keys stay at the top level; whatever is left over is
    # grouped under "entry.params". Note that `hparams` is modified in
    # place by the pops below.
    formatted = {}
    for key in ("model.class", "model", "model.params"):
        if key in hparams:
            formatted[key] = hparams.pop(key)
    if hparams:
        formatted["entry.params"] = hparams

    # NOTE(review): best-effort fallback -- *any* exception raised while
    # indexing `model_dirs`, loading the model configs, or merging them
    # silently drops the model-directory layer. Confirm this breadth is
    # intended.
    try:
        with_model_cfgs = deep_merge_dict(
            ModelConfigs.load(model_dirs[0]), formatted)
        return deep_merge_dict(with_model_cfgs, cfg_file_args)
    except Exception:
        return deep_merge_dict(formatted, cfg_file_args)
コード例 #2
0
ファイル: flags_core.py プロジェクト: lileicc/neurst
def _args_preload_from_config_files(args):
    """Load the arguments stored in the config file(s) referenced by `args`.

    Args:
        args: Parsed-arguments object. The attribute named by
            ``DEFAULT_CONFIG_FLAG.name`` is read; ``None`` is used when
            the attribute is absent.

    Returns:
        Whatever ``yaml_load_checking`` produces for the loaded config
        content.
    """
    config_paths = flatten_string_list(
        getattr(args, DEFAULT_CONFIG_FLAG.name, None))
    return yaml_load_checking(load_from_config_path(config_paths))