def create(
    config: Config,
    dataset: Dataset,
    configuration_key: Optional[str] = None,
    init_for_load_only=False,
) -> "KgeModel":
    """Factory method for model creation.

    Determines the model type from the configuration (from
    ``<configuration_key>.type`` when a key is given, else from the
    top-level ``model`` option), imports the corresponding module,
    instantiates the model class, and moves it to the configured device.

    :param config: global configuration object
    :param dataset: dataset the model will operate on
    :param configuration_key: optional config key under which the model
        type is configured; ``None`` means use the top-level ``model`` option
    :param init_for_load_only: forwarded to the model constructor
    :raises Exception: when the model type cannot be resolved from the
        configuration
    """
    try:
        if configuration_key is not None:
            model_name = config.get(configuration_key + ".type")
        else:
            model_name = config.get("model")
        config._import(model_name)
        class_name = config.get(model_name + ".class_name")
    except Exception as e:
        # was a bare `except:` that also swallowed KeyboardInterrupt and
        # discarded the original cause; narrow it and chain the cause so the
        # underlying lookup/import failure stays visible in the traceback
        raise Exception(
            "Can't find {}.type in config".format(configuration_key)
        ) from e

    try:
        model = init_from(
            class_name,
            config.get("modules"),
            config=config,
            dataset=dataset,
            configuration_key=configuration_key,
            init_for_load_only=init_for_load_only,
        )
        model.to(config.get("job.device"))
        return model
    except BaseException:
        # log context for any failure (including interrupts), then re-raise
        # unchanged so callers see the original exception type
        config.log(f"Failed to create model {model_name} (class {class_name}).")
        raise
def create_config(test_dataset_name: str, model: str = "complex") -> Config:
    """Build a minimal, quiet, CPU-only test configuration.

    :param test_dataset_name: name of the dataset to configure
    :param model: model type to configure and import (default: ``complex``)
    :return: a fresh :class:`Config` with no output folder set
    """
    cfg = Config()
    cfg.folder = None
    # silence console output and select the model before importing its
    # module (the import may register model-specific options)
    cfg.set("console.quiet", True)
    cfg.set("model", model)
    cfg._import(model)
    # dataset and device come last; both are plain option assignments
    for key, value in (("dataset.name", test_dataset_name), ("job.device", "cpu")):
        cfg.set(key, value)
    return cfg
def _dump_config(args):
    """Execute the 'dump config' command.

    Prints a configuration loaded from ``args.source``, which may be an
    experiment folder (containing ``config.yaml``), a YAML file, or a
    checkpoint file. Exactly one of ``--raw``, ``--full``, or ``--minimal``
    selects the output form; ``--include``/``--exclude`` filter option keys
    (not allowed together with ``--raw``).
    """
    # default to --minimal when no output mode was chosen
    if not (args.raw or args.full or args.minimal):
        args.minimal = True

    # bool addition: exactly one of the three flags may be truthy
    if args.raw + args.full + args.minimal != 1:
        raise ValueError(
            "Exactly one of --raw, --full, or --minimal must be set")

    if args.raw and (args.include or args.exclude):
        raise ValueError("--include and --exclude cannot be used with --raw "
                         "(use --full or --minimal instead).")

    config = Config()
    config_file = None
    if os.path.isdir(args.source):
        # experiment folder: read its config.yaml
        config_file = os.path.join(args.source, "config.yaml")
        config.load(config_file)
    elif ".yaml" in os.path.split(args.source)[-1]:
        # a YAML config file given directly
        config_file = args.source
        config.load(config_file)
    else:  # a checkpoint
        checkpoint = torch.load(args.source, map_location="cpu")
        if args.raw:
            # raw mode prints the checkpoint's stored config object as-is
            config = checkpoint["config"]
        else:
            config.load_config(checkpoint["config"])

    def print_options(options):
        # Filter `options` by --include/--exclude (matching any dotted key
        # prefix), then print the remainder as YAML. Mutates `options`.
        # drop all arguments that are not included
        if args.include:
            args.include = set(args.include)
            options_copy = copy.deepcopy(options)
            for key in options_copy.keys():
                # walk the key's dotted prefixes (a.b.c -> a.b -> a) and keep
                # the key if any prefix is in the include set
                prefix = key
                keep = False
                while True:
                    if prefix in args.include:
                        keep = True
                        break
                    else:
                        last_dot_index = prefix.rfind(".")
                        if last_dot_index < 0:
                            break
                        else:
                            prefix = prefix[:last_dot_index]
                if not keep:
                    del options[key]

        # remove all arguments that are excluded
        if args.exclude:
            args.exclude = set(args.exclude)
            options_copy = copy.deepcopy(options)
            for key in options_copy.keys():
                # same prefix walk as above, but deleting on a match
                prefix = key
                while True:
                    if prefix in args.exclude:
                        del options[key]
                        break
                    else:
                        last_dot_index = prefix.rfind(".")
                        if last_dot_index < 0:
                            break
                        else:
                            prefix = prefix[:last_dot_index]

        # convert the remaining options to a Config and print it
        config = Config(load_default=False)
        config.set_all(options, create=True)
        print(yaml.dump(config.options))

    if args.raw:
        if config_file:
            # print the on-disk file verbatim, preserving comments/formatting
            with open(config_file, "r") as f:
                print(f.read())
        else:
            # raw from a checkpoint: no file on disk, print the options
            print_options(config.options)
    elif args.full:
        print_options(config.options)
    else:  # minimal
        # diff against a default config (with the same imports applied) and
        # keep only the options that differ
        default_config = Config()
        imports = config.get("import")
        if imports is not None:
            if not isinstance(imports, list):
                imports = [imports]
            for module_name in imports:
                default_config._import(module_name)
        default_options = Config.flatten(default_config.options)
        new_options = Config.flatten(config.options)
        minimal_options = {}

        for option, value in new_options.items():
            if option not in default_options or default_options[
                    option] != value:
                minimal_options[option] = value

        # always retain all imports
        if imports is not None:
            minimal_options["import"] = list(set(imports))

        print_options(minimal_options)
def main():
    """Entry point of the kge command line interface.

    Parses arguments, dispatches meta-commands (dump/package), loads or
    resumes a configuration, overwrites it with command-line options,
    initializes the output folder, seeds the PRNGs, and finally creates
    and runs the requested job.
    """
    # default config
    config = Config()

    # now parse the arguments
    parser = create_parser(config)
    args, unknown_args = parser.parse_known_args()

    # If there were unknown args, add them to the parser and reparse. The correctness
    # of these arguments will be checked later.
    if len(unknown_args) > 0:
        parser = create_parser(
            config, filter(lambda a: a.startswith("--"), unknown_args)
        )
        args = parser.parse_args()

    # process meta-commands (rewrite shorthand commands into base commands)
    process_meta_command(args, "create", {"command": "start", "run": False})
    process_meta_command(args, "eval", {"command": "resume", "job.type": "eval"})
    process_meta_command(
        args, "test", {"command": "resume", "job.type": "eval", "eval.split": "test"}
    )
    process_meta_command(
        args, "valid", {"command": "resume", "job.type": "eval", "eval.split": "valid"}
    )
    # dump command
    if args.command == "dump":
        dump(args)
        exit()

    # package command
    if args.command == "package":
        package_model(args)
        exit()

    # start command
    if args.command == "start":
        # use toy config file if no config given
        if args.config is None:
            args.config = kge_base_dir() + "/" + "examples/toy-complex-train.yaml"
            print(
                "WARNING: No configuration specified; using " + args.config,
                file=sys.stderr,
            )
        if not vars(args)["console.quiet"]:
            print("Loading configuration {}...".format(args.config))
        config.load(args.config)

    # resume command
    if args.command == "resume":
        # a folder argument is resolved to the config.yaml inside it
        if os.path.isdir(args.config) and os.path.isfile(args.config + "/config.yaml"):
            args.config += "/config.yaml"
        if not vars(args)["console.quiet"]:
            print("Resuming from configuration {}...".format(args.config))
        config.load(args.config)
        config.folder = os.path.dirname(args.config)
        if not config.folder:
            config.folder = "."
        if not os.path.exists(config.folder):
            raise ValueError(
                "{} is not a valid config file for resuming".format(args.config)
            )

    # overwrite configuration with command line arguments
    for key, value in vars(args).items():
        # these keys are CLI-only and never stored in the config
        if key in [
            "command",
            "config",
            "run",
            "folder",
            "checkpoint",
            "abort_when_cache_outdated",
        ]:
            continue
        if value is not None:
            if key == "search.device_pool":
                # argparse gives a list of fragments; join and re-split on commas
                value = "".join(value).split(",")
            try:
                # coerce string values to bool when the existing option is bool
                if isinstance(config.get(key), bool):
                    value = argparse_bool_type(value)
            except KeyError:
                pass
            config.set(key, value)
            if key == "model":
                config._import(value)

    # initialize output folder
    if args.command == "start":
        if args.folder is None:  # means: set default
            config_name = os.path.splitext(os.path.basename(args.config))[0]
            config.folder = os.path.join(
                kge_base_dir(),
                "local",
                "experiments",
                datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + config_name,
            )
        else:
            config.folder = args.folder

    # catch errors to log them
    try:
        if args.command == "start" and not config.init_folder():
            raise ValueError("output folder {} exists already".format(config.folder))
        config.log("Using folder: {}".format(config.folder))

        # determine checkpoint to resume (if any)
        # NOTE(review): checkpoint_file is only bound when args has a
        # "checkpoint" attribute, but it is read below in the resume branch —
        # confirm resume always defines args.checkpoint, else this can NameError
        if hasattr(args, "checkpoint"):
            checkpoint_file = get_checkpoint_file(config, args.checkpoint)

        # disable processing of outdated cached dataset files globally
        Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated

        # log configuration
        config.log("Configuration:")
        config.log(yaml.dump(config.options), prefix="  ")
        config.log("git commit: {}".format(get_git_revision_short_hash()), prefix="  ")

        # set random seeds
        def get_seed(what):
            # Return the configured seed for PRNG `what`; when unset (<0),
            # derive one from random_seed.default if that is set.
            seed = config.get(f"random_seed.{what}")
            if seed < 0 and config.get(f"random_seed.default") >= 0:
                import hashlib

                # we add an md5 hash to the default seed so that different PRNGs get a
                # different seed
                seed = (
                    config.get(f"random_seed.default")
                    + int(hashlib.md5(what.encode()).hexdigest(), 16)
                ) % 0xFFFF  # stay 32-bit
            return seed

        if get_seed("python") > -1:
            import random

            random.seed(get_seed("python"))
        if get_seed("torch") > -1:
            import torch

            torch.manual_seed(get_seed("torch"))
        if get_seed("numpy") > -1:
            import numpy.random

            numpy.random.seed(get_seed("numpy"))
        if get_seed("numba") > -1:
            import numpy as np, numba

            # numba keeps its own PRNG state; seeding must happen inside a
            # jitted function for it to take effect
            @numba.njit
            def seed_numba(seed):
                np.random.seed(seed)

            seed_numba(get_seed("numba"))

        # let's go
        if args.command == "start" and not args.run:
            config.log("Job created successfully.")
        else:
            # load data
            dataset = Dataset.create(config)

            # let's go
            if args.command == "resume":
                if checkpoint_file is not None:
                    checkpoint = load_checkpoint(
                        checkpoint_file, config.get("job.device")
                    )
                    job = Job.create_from(
                        checkpoint, new_config=config, dataset=dataset
                    )
                else:
                    job = Job.create(config, dataset)
                    job.config.log(
                        "No checkpoint found or specified, starting from scratch..."
                    )
            else:
                job = Job.create(config, dataset)
            job.run()
    except BaseException:
        # log the traceback into the experiment log, then re-raise unchanged
        tb = traceback.format_exc()
        config.log(tb, echo=False)
        raise
def main():
    """Entry point of the kge command line interface (older variant).

    Parses arguments, dispatches the dump meta-command, loads or resumes
    a configuration, overwrites it with command-line options, initializes
    the output folder, seeds the PRNGs, and runs the requested job.
    """
    # default config
    config = Config()

    # now parse the arguments
    parser = create_parser(config)
    args, unknown_args = parser.parse_known_args()

    # If there were unknown args, add them to the parser and reparse. The correctness
    # of these arguments will be checked later.
    if len(unknown_args) > 0:
        parser = create_parser(
            config, filter(lambda a: a.startswith("--"), unknown_args)
        )
        args = parser.parse_args()

    # process meta-commands (rewrite shorthand commands into base commands)
    process_meta_command(args, "create", {"command": "start", "run": False})
    process_meta_command(args, "eval", {"command": "resume", "job.type": "eval"})
    process_meta_command(
        args, "test", {"command": "resume", "job.type": "eval", "eval.split": "test"}
    )
    process_meta_command(
        args, "valid", {"command": "resume", "job.type": "eval", "eval.split": "valid"}
    )
    # dump command
    if args.command == "dump":
        dump(args)
        exit()

    # start command
    if args.command == "start":
        # use toy config file if no config given
        if args.config is None:
            args.config = kge_base_dir() + "/" + "examples/toy-complex-train.yaml"
            print("WARNING: No configuration specified; using " + args.config)
        print("Loading configuration {}...".format(args.config))
        config.load(args.config)

    # resume command
    if args.command == "resume":
        # a folder argument is resolved to the config.yaml inside it
        if os.path.isdir(args.config) and os.path.isfile(args.config + "/config.yaml"):
            args.config += "/config.yaml"
        print("Resuming from configuration {}...".format(args.config))
        config.load(args.config)
        config.folder = os.path.dirname(args.config)
        if not config.folder:
            config.folder = "."
        if not os.path.exists(config.folder):
            raise ValueError(
                "{} is not a valid config file for resuming".format(args.config)
            )

    # overwrite configuration with command line arguments
    for key, value in vars(args).items():
        # these keys are CLI-only and never stored in the config
        if key in [
            "command",
            "config",
            "run",
            "folder",
            "checkpoint",
            "abort_when_cache_outdated",
        ]:
            continue
        if value is not None:
            if key == "search.device_pool":
                # argparse gives a list of fragments; join and re-split on commas
                value = "".join(value).split(",")
            try:
                # coerce string values to bool when the existing option is bool
                if isinstance(config.get(key), bool):
                    value = argparse_bool_type(value)
            except KeyError:
                pass
            config.set(key, value)
            if key == "model":
                config._import(value)

    # initialize output folder
    if args.command == "start":
        if args.folder is None:  # means: set default
            config_name = os.path.splitext(os.path.basename(args.config))[0]
            config.folder = os.path.join(
                kge_base_dir(),
                "local",
                "experiments",
                datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + config_name,
            )
        else:
            config.folder = args.folder

    # catch errors to log them
    try:
        if args.command == "start" and not config.init_folder():
            raise ValueError("output folder {} exists already".format(config.folder))
        config.log("Using folder: {}".format(config.folder))

        # determine checkpoint to resume (if any)
        # NOTE(review): checkpoint_file is only bound when args has a
        # "checkpoint" attribute, but it is read below in the resume branch —
        # confirm resume always defines args.checkpoint, else this can NameError
        if hasattr(args, "checkpoint"):
            if args.checkpoint == "default":
                # eval-type jobs default to the best checkpoint; training
                # resumes from the last one (signalled by None)
                if config.get("job.type") in ["eval", "valid"]:
                    checkpoint_file = config.checkpoint_file("best")
                else:
                    checkpoint_file = None  # means last
            elif is_number(args.checkpoint, int) or args.checkpoint == "best":
                checkpoint_file = config.checkpoint_file(args.checkpoint)
            else:
                # otherwise, treat it as a filename
                checkpoint_file = args.checkpoint

        # disable processing of outdated cached dataset files globally
        Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated

        # log configuration
        config.log("Configuration:")
        config.log(yaml.dump(config.options), prefix="  ")
        config.log("git commit: {}".format(get_git_revision_short_hash()), prefix="  ")

        # set random seeds
        if config.get("random_seed.python") > -1:
            import random

            random.seed(config.get("random_seed.python"))
        if config.get("random_seed.torch") > -1:
            import torch

            torch.manual_seed(config.get("random_seed.torch"))
        if config.get("random_seed.numpy") > -1:
            import numpy.random

            numpy.random.seed(config.get("random_seed.numpy"))

        # let's go
        if args.command == "start" and not args.run:
            config.log("Job created successfully.")
        else:
            # load data
            dataset = Dataset.load(config)

            # let's go
            job = Job.create(config, dataset)
            if args.command == "resume":
                job.resume(checkpoint_file)
            job.run()
    except BaseException as e:
        # log the traceback into the experiment log; `from None` suppresses
        # the implicit exception-context chaining on re-raise
        tb = traceback.format_exc()
        config.log(tb, echo=False)
        raise e from None