def load(config: Config, preload_data=True):
    """Loads a dataset.

    If preload_data is set, loads entity and relation maps as well as all splits.
    Otherwise, this data is lazy loaded on first use.
    """
    name = config.get("dataset.name")
    folder = os.path.join(kge_base_dir(), "data", name)
    if os.path.isfile(os.path.join(folder, "dataset.yaml")):
        config.log("Loading configuration of dataset " + name + "...")
        config.load(os.path.join(folder, "dataset.yaml"))

    dataset = Dataset(config, folder)
    if preload_data:
        dataset.entity_ids()
        dataset.relation_ids()
        for split in ["train", "valid", "test"]:
            dataset.split(split)
    return dataset
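# Usage sketch, not part of the file above: assumes LibKGE is installed, that
# Config can be imported from the kge package, and that a dataset named "toy"
# has been downloaded to <kge_base_dir>/data/toy.
if __name__ == "__main__":
    from kge import Config

    config = Config()
    config.set("dataset.name", "toy")
    # with preload_data=False, entity/relation maps and splits are loaded lazily
    dataset = load(config, preload_data=False)
    train = dataset.split("train")  # first access triggers loading of the split
    print(train.shape)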
def main():
    # default config
    config = Config()

    # now parse the arguments
    parser = create_parser(config)
    args, unknown_args = parser.parse_known_args()

    # If there were unknown args, add them to the parser and reparse. The correctness
    # of these arguments will be checked later.
    if len(unknown_args) > 0:
        parser = create_parser(
            config, filter(lambda a: a.startswith("--"), unknown_args)
        )
        args = parser.parse_args()

    # process meta-commands
    process_meta_command(args, "create", {"command": "start", "run": False})
    process_meta_command(args, "eval", {"command": "resume", "job.type": "eval"})
    process_meta_command(
        args, "test", {"command": "resume", "job.type": "eval", "eval.split": "test"}
    )
    process_meta_command(
        args, "valid", {"command": "resume", "job.type": "eval", "eval.split": "valid"}
    )

    # dump command
    if args.command == "dump":
        dump(args)
        exit()

    # package command
    if args.command == "package":
        package_model(args)
        exit()

    # start command
    if args.command == "start":
        # use toy config file if no config given
        if args.config is None:
            args.config = kge_base_dir() + "/" + "examples/toy-complex-train.yaml"
            print(
                "WARNING: No configuration specified; using " + args.config,
                file=sys.stderr,
            )
        if not vars(args)["console.quiet"]:
            print("Loading configuration {}...".format(args.config))
        config.load(args.config)

    # resume command
    if args.command == "resume":
        if os.path.isdir(args.config) and os.path.isfile(args.config + "/config.yaml"):
            args.config += "/config.yaml"
        if not vars(args)["console.quiet"]:
            print("Resuming from configuration {}...".format(args.config))
        config.load(args.config)
        config.folder = os.path.dirname(args.config)
        if not config.folder:
            config.folder = "."
        if not os.path.exists(config.folder):
            raise ValueError(
                "{} is not a valid config file for resuming".format(args.config)
            )

    # overwrite configuration with command line arguments
    for key, value in vars(args).items():
        if key in [
            "command",
            "config",
            "run",
            "folder",
            "checkpoint",
            "abort_when_cache_outdated",
        ]:
            continue
        if value is not None:
            if key == "search.device_pool":
                value = "".join(value).split(",")
            try:
                if isinstance(config.get(key), bool):
                    value = argparse_bool_type(value)
            except KeyError:
                pass
            config.set(key, value)
            if key == "model":
                config._import(value)

    # initialize output folder
    if args.command == "start":
        if args.folder is None:  # means: set default
            config_name = os.path.splitext(os.path.basename(args.config))[0]
            config.folder = os.path.join(
                kge_base_dir(),
                "local",
                "experiments",
                datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + config_name,
            )
        else:
            config.folder = args.folder

    # catch errors to log them
    try:
        if args.command == "start" and not config.init_folder():
            raise ValueError("output folder {} exists already".format(config.folder))
        config.log("Using folder: {}".format(config.folder))

        # determine checkpoint to resume (if any)
        if hasattr(args, "checkpoint"):
            checkpoint_file = get_checkpoint_file(config, args.checkpoint)

        # disable processing of outdated cached dataset files globally
        Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated

        # log configuration
        config.log("Configuration:")
        config.log(yaml.dump(config.options), prefix=" ")
        config.log("git commit: {}".format(get_git_revision_short_hash()), prefix=" ")

        # set random seeds
        def get_seed(what):
            seed = config.get(f"random_seed.{what}")
            if seed < 0 and config.get(f"random_seed.default") >= 0:
                import hashlib

                # we add an md5 hash to the default seed so that different PRNGs get a
                # different seed
                seed = (
                    config.get(f"random_seed.default")
                    + int(hashlib.md5(what.encode()).hexdigest(), 16)
                ) % 0xFFFF  # keep the seed in a small positive range
            return seed

        if get_seed("python") > -1:
            import random

            random.seed(get_seed("python"))
        if get_seed("torch") > -1:
            import torch

            torch.manual_seed(get_seed("torch"))
        if get_seed("numpy") > -1:
            import numpy.random

            numpy.random.seed(get_seed("numpy"))
        if get_seed("numba") > -1:
            import numpy as np, numba

            @numba.njit
            def seed_numba(seed):
                np.random.seed(seed)

            seed_numba(get_seed("numba"))

        # let's go
        if args.command == "start" and not args.run:
            config.log("Job created successfully.")
        else:
            # load data
            dataset = Dataset.create(config)

            # let's go
            if args.command == "resume":
                if checkpoint_file is not None:
                    checkpoint = load_checkpoint(
                        checkpoint_file, config.get("job.device")
                    )
                    job = Job.create_from(
                        checkpoint, new_config=config, dataset=dataset
                    )
                else:
                    job = Job.create(config, dataset)
                    job.config.log(
                        "No checkpoint found or specified, starting from scratch..."
                    )
            else:
                job = Job.create(config, dataset)
            job.run()
    except BaseException:
        tb = traceback.format_exc()
        config.log(tb, echo=False)
        raise
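# Standalone illustration of the per-PRNG seed derivation used in get_seed()
# above (assumption: a default seed of 42 and no per-PRNG overrides; written
# without a Config object purely for demonstration).
import hashlib


def derive_seed(what: str, default_seed: int = 42) -> int:
    # each PRNG name ("python", "torch", "numpy", "numba") yields a distinct
    # seed: the default seed plus the md5 hash of the name, reduced modulo 0xFFFF
    return (default_seed + int(hashlib.md5(what.encode()).hexdigest(), 16)) % 0xFFFF


if __name__ == "__main__":
    for prng in ["python", "torch", "numpy", "numba"]:
        print(prng, derive_seed(prng))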
import unittest
from tests.util import get_dataset_folder
import sys
from kge.misc import kge_base_dir
import os
from os import path

sys.path.append(path.join(kge_base_dir(), "data/preprocess"))
from data.preprocess.util import analyze_raw_splits
from data.preprocess.util import RawDataset
from data.preprocess.util import Split
from data.preprocess.util import SampledSplit
from data.preprocess.util import FilteredSplit
from data.preprocess.util import RawSplit
from data.preprocess.util import write_dataset_yaml
from data.preprocess.util import process_splits

import yaml


class TestPreprocess(unittest.TestCase):
    def setUp(self) -> None:
        self.dataset_name = "dataset_preprocess"
        self.dataset_folder = get_dataset_folder(self.dataset_name)

    def tearDown(self) -> None:
        self.remove_del_files()

    def test_analyze_splits(self):
        raw_splits = TestPreprocess.get_raw_splits()
        raw_dataset: RawDataset = analyze_raw_splits(
def get_dataset_folder(dataset_name):
    return os.path.join(kge_base_dir(), "tests", "data", dataset_name)
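# Usage sketch, not part of the test utilities above: assumes the snippet's own
# imports (os, kge_base_dir) are in scope and that this runs inside a LibKGE
# checkout, so kge_base_dir() points at the repository root.
if __name__ == "__main__":
    folder = get_dataset_folder("dataset_preprocess")
    print(folder)  # e.g. <repo root>/tests/data/dataset_preprocess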
def main():
    # default config
    config = Config()

    # now parse the arguments
    parser = create_parser(config)
    args, unknown_args = parser.parse_known_args()

    # If there were unknown args, add them to the parser and reparse. The correctness
    # of these arguments will be checked later.
    if len(unknown_args) > 0:
        parser = create_parser(
            config, filter(lambda a: a.startswith("--"), unknown_args)
        )
        args = parser.parse_args()

    # process meta-commands
    process_meta_command(args, "create", {"command": "start", "run": False})
    process_meta_command(args, "eval", {"command": "resume", "job.type": "eval"})
    process_meta_command(
        args, "test", {"command": "resume", "job.type": "eval", "eval.split": "test"}
    )
    process_meta_command(
        args, "valid", {"command": "resume", "job.type": "eval", "eval.split": "valid"}
    )

    # dump command
    if args.command == "dump":
        dump(args)
        exit()

    # start command
    if args.command == "start":
        # use toy config file if no config given
        if args.config is None:
            args.config = kge_base_dir() + "/" + "examples/toy-complex-train.yaml"
            print("WARNING: No configuration specified; using " + args.config)
        print("Loading configuration {}...".format(args.config))
        config.load(args.config)

    # resume command
    if args.command == "resume":
        if os.path.isdir(args.config) and os.path.isfile(args.config + "/config.yaml"):
            args.config += "/config.yaml"
        print("Resuming from configuration {}...".format(args.config))
        config.load(args.config)
        config.folder = os.path.dirname(args.config)
        if not config.folder:
            config.folder = "."
        if not os.path.exists(config.folder):
            raise ValueError(
                "{} is not a valid config file for resuming".format(args.config)
            )

    # overwrite configuration with command line arguments
    for key, value in vars(args).items():
        if key in [
            "command",
            "config",
            "run",
            "folder",
            "checkpoint",
            "abort_when_cache_outdated",
        ]:
            continue
        if value is not None:
            if key == "search.device_pool":
                value = "".join(value).split(",")
            try:
                if isinstance(config.get(key), bool):
                    value = argparse_bool_type(value)
            except KeyError:
                pass
            config.set(key, value)
            if key == "model":
                config._import(value)

    # initialize output folder
    if args.command == "start":
        if args.folder is None:  # means: set default
            config_name = os.path.splitext(os.path.basename(args.config))[0]
            config.folder = os.path.join(
                kge_base_dir(),
                "local",
                "experiments",
                datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + config_name,
            )
        else:
            config.folder = args.folder

    # catch errors to log them
    try:
        if args.command == "start" and not config.init_folder():
            raise ValueError("output folder {} exists already".format(config.folder))
        config.log("Using folder: {}".format(config.folder))

        # determine checkpoint to resume (if any)
        if hasattr(args, "checkpoint"):
            if args.checkpoint == "default":
                if config.get("job.type") in ["eval", "valid"]:
                    checkpoint_file = config.checkpoint_file("best")
                else:
                    checkpoint_file = None  # means last
            elif is_number(args.checkpoint, int) or args.checkpoint == "best":
                checkpoint_file = config.checkpoint_file(args.checkpoint)
            else:
                # otherwise, treat it as a filename
                checkpoint_file = args.checkpoint

        # disable processing of outdated cached dataset files globally
        Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated

        # log configuration
        config.log("Configuration:")
        config.log(yaml.dump(config.options), prefix=" ")
        config.log("git commit: {}".format(get_git_revision_short_hash()), prefix=" ")

        # set random seeds
        if config.get("random_seed.python") > -1:
            import random

            random.seed(config.get("random_seed.python"))
        if config.get("random_seed.torch") > -1:
            import torch

            torch.manual_seed(config.get("random_seed.torch"))
        if config.get("random_seed.numpy") > -1:
            import numpy.random

            numpy.random.seed(config.get("random_seed.numpy"))

        # let's go
        if args.command == "start" and not args.run:
            config.log("Job created successfully.")
        else:
            # load data
            dataset = Dataset.load(config)

            # let's go
            job = Job.create(config, dataset)
            if args.command == "resume":
                job.resume(checkpoint_file)
            job.run()
    except BaseException as e:
        tb = traceback.format_exc()
        config.log(tb, echo=False)
        raise e from None