def _split_predefined_parameters(predefined_parameters):
    """Rearrange a predefined hyper-parameter mapping into config layout.

    The keys "model.class", "model" and "model.params" (when present) are
    placed at the top level of the result, in that order; every remaining key
    is grouped under "entry.params" (which is omitted when nothing remains).
    The input mapping is copied first and is NOT modified, so a mapping that
    ``get_hyper_parameters`` may share between calls stays intact.

    Args:
        predefined_parameters: Mapping returned by ``get_hyper_parameters``.

    Returns:
        A new dict in the layout described above.
    """
    # Shallow copy: only top-level keys are popped below.
    remaining = dict(predefined_parameters)
    formatted_parameters = {}
    for key in ("model.class", "model", "model.params"):
        if key in remaining:
            formatted_parameters[key] = remaining.pop(key)
    if remaining:
        formatted_parameters["entry.params"] = remaining
    return formatted_parameters


def _pre_load_args(args):
    """Pre-load configuration from config files, an hparams set and a model dir.

    Three sources are combined with ``deep_merge_dict``:
      1. configs stored in the first model directory (``ModelConfigs.load``),
         used only if a model dir is given and loading/merging succeeds;
      2. the predefined hyper-parameter set named by ``args.hparams_set``
         (falling back to "hparams_set" from the config file);
      3. the arguments loaded from the config file(s) named by the default
         config flag on ``args``.
    They are merged in the order 1 -> 2 -> 3; presumably later sources override
    earlier ones -- NOTE(review): confirm against ``deep_merge_dict``.

    Args:
        args: Parsed command-line namespace; must expose ``model_dir``,
            ``hparams_set`` and the attribute named by
            ``flags_core.DEFAULT_CONFIG_FLAG.name``.

    Returns:
        The merged configuration as returned by ``deep_merge_dict``.
    """
    cfg_file_args = yaml_load_checking(
        load_from_config_path(
            flatten_string_list(
                getattr(args, flags_core.DEFAULT_CONFIG_FLAG.name))))

    # Command-line values take priority over the ones from the config file.
    model_dirs = flatten_string_list(
        args.model_dir or cfg_file_args.get("model_dir", None))
    hparams_set = args.hparams_set
    if hparams_set is None:
        hparams_set = cfg_file_args.get("hparams_set", None)

    formatted_parameters = _split_predefined_parameters(
        get_hyper_parameters(hparams_set))

    # Explicit guard instead of relying on model_dirs[0] raising and being
    # swallowed by the except clause below.
    if model_dirs:
        try:
            model_cfgs = ModelConfigs.load(model_dirs[0])
            return deep_merge_dict(
                deep_merge_dict(model_cfgs, formatted_parameters),
                cfg_file_args)
        except Exception:
            # Deliberate best effort: the model dir may not contain usable
            # configs yet, in which case fall through to the merge that
            # ignores model-dir configs.
            pass
    return deep_merge_dict(formatted_parameters, cfg_file_args)
def _args_preload_from_config_files(args):
    """Return the arguments defined by the config file(s) referenced on ``args``.

    The value of the attribute named ``DEFAULT_CONFIG_FLAG.name`` (or None when
    ``args`` lacks it) is flattened into a list, passed to
    ``load_from_config_path``, and the result is run through
    ``yaml_load_checking``. NOTE(review): presumably this parses YAML content
    into a dict -- confirm against those helpers.

    Args:
        args: Parsed command-line namespace.

    Returns:
        Whatever ``yaml_load_checking`` produces for the loaded config(s).
    """
    raw_flag_value = getattr(args, DEFAULT_CONFIG_FLAG.name, None)
    config_paths = flatten_string_list(raw_flag_value)
    loaded_config = load_from_config_path(config_paths)
    return yaml_load_checking(loaded_config)