def test_data_pickle_correctness(self):
    """Triples and metadata loaded from pickle caches must equal freshly loaded ones."""
    # The first creation writes new pickle files for train, valid, and test.
    fresh = Dataset.create(
        config=self.config, folder=self.dataset_folder, preload_data=True
    )
    # The second creation loads the triples back from the stored pickle files.
    cached = Dataset.create(
        config=self.config, folder=self.dataset_folder, preload_data=True
    )
    for split_name in fresh._triples:
        self.assertTrue(
            torch.all(torch.eq(cached.split(split_name), fresh.split(split_name)))
        )
    self.assertEqual(fresh._meta, cached._meta)
def create(config: Config, dataset: Optional[Dataset] = None, parent_job=None, model=None):
    "Create a new job."
    # Imported locally to avoid a circular import at module load time.
    from kge.job import TrainingJob, EvaluationJob, SearchJob

    if dataset is None:
        dataset = Dataset.create(config)

    # Dispatch on the configured job type via a factory table.
    factories = {
        "train": lambda: TrainingJob.create(
            config, dataset, parent_job=parent_job, model=model
        ),
        "search": lambda: SearchJob.create(config, dataset, parent_job=parent_job),
        "eval": lambda: EvaluationJob.create(
            config, dataset, parent_job=parent_job, model=model
        ),
    }
    job_type = config.get("job.type")
    if job_type not in factories:
        raise ValueError("unknown job type")
    return factories[job_type]()
def _create_dataset_and_indexes():
    """Load the dataset and compute every registered index.

    Returns:
        A tuple ``(data, indexes)`` where ``data`` is the created
        :class:`Dataset` and ``indexes`` lists the result of
        ``data.index(key)`` for each registered index function.
    """
    data = Dataset.create(
        config=self.config, folder=self.dataset_folder, preload_data=True
    )
    # Build each registered index; a comprehension replaces the manual
    # append loop, and iterating the dict directly avoids ``.keys()``.
    indexes = [data.index(index_key) for index_key in data.index_functions]
    return data, indexes
def setUp(self):
    """Build the config, dataset, and model fixtures for this test case."""
    self.config = create_config(self.dataset_name, model=self.model_name)
    self.config.set_all({"lookup_embedder.dim": 32})
    self.config.set_all(self.options)
    self.dataset_folder = get_dataset_folder(self.dataset_name)
    # Reuse the folder resolved above instead of calling
    # get_dataset_folder a second time with the same argument.
    self.dataset = Dataset.create(self.config, folder=self.dataset_folder)
    self.model = KgeModel.create(self.config, self.dataset)
def setUp(self):
    """Build the config, dataset, and model fixtures for a training test."""
    self.config = create_config(self.dataset_name)
    self.config.set_all({"lookup_embedder.dim": 32})
    self.config.set("job.type", "train")
    self.config.set("train.type", self.train_type)
    self.config.set_all(self.options)
    self.dataset_folder = get_dataset_folder(self.dataset_name)
    # Reuse the folder resolved above instead of calling
    # get_dataset_folder a second time with the same argument.
    self.dataset = Dataset.create(self.config, folder=self.dataset_folder)
    self.model = KgeModel.create(self.config, self.dataset)
def test_store_index_pickle(self):
    """Computing an index must store a pickle file in the dataset folder."""
    dataset = Dataset.create(
        config=self.config, folder=self.dataset_folder, preload_data=True
    )
    for index_key in dataset.index_functions.keys():
        dataset.index(index_key)
        pickle_filename = os.path.join(
            self.dataset_folder,
            Dataset._to_valid_filename(f"index-{index_key}.pckl"),
        )
        # Fix: pickle_filename already includes self.dataset_folder; joining
        # it with the folder again would duplicate the prefix whenever the
        # folder is a relative path and make the existence check test the
        # wrong location.
        self.assertTrue(os.path.isfile(pickle_filename), msg=pickle_filename)
def test_store_data_pickle(self):
    """Creating a preloaded dataset must write the expected pickle caches."""
    # Dataset creation with preload_data writes new pickle files for
    # train/valid/test as well as the entity and relation id maps.
    Dataset.create(
        config=self.config, folder=self.dataset_folder, preload_data=True
    )
    expected_files = (
        "train.del-t.pckl",
        "valid.del-t.pckl",
        "test.del-t.pckl",
        "entity_ids.del-True-t-False.pckl",
        "relation_ids.del-True-t-False.pckl",
    )
    for name in expected_files:
        path = os.path.join(self.dataset_folder, name)
        self.assertTrue(os.path.isfile(path), msg=name)
def create_default(
    model: Optional[str] = None,
    dataset: Optional[Union[Dataset, str]] = None,
    options: Optional[Dict[str, Any]] = None,
    folder: Optional[str] = None,
) -> "KgeModel":
    """Utility method to create a model, including configuration and dataset.

    `model` is the name of the model (takes precedence over
    ``options["model"]``), `dataset` a dataset name or `Dataset` instance
    (takes precedence over ``options["dataset.name"]``), and options
    arbitrary other configuration options.

    If `folder` is ``None``, creates a temporary folder. Otherwise uses the
    specified folder.

    """
    # Fix: avoid a mutable default argument ({} shared across calls);
    # passing None now yields a fresh empty dict, which is behavior-compatible.
    if options is None:
        options = {}

    # load default model config
    if model is None:
        model = options["model"]
    default_config_file = filename_in_module(kge.model, "{}.yaml".format(model))
    config = Config()
    config.load(default_config_file, create=True)

    # apply specified options
    config.set("model", model)
    if isinstance(dataset, Dataset):
        config.set("dataset.name", dataset.config.get("dataset.name"))
    elif isinstance(dataset, str):
        config.set("dataset.name", dataset)
    config.set_all(new_options=options)

    # create output folder
    if folder is None:
        config.folder = tempfile.mkdtemp(
            "{}-{}-".format(config.get("dataset.name"), config.get("model"))
        )
    else:
        config.folder = folder

    # create dataset and model
    if not isinstance(dataset, Dataset):
        dataset = Dataset.create(config)
    model = KgeModel.create(config, dataset)
    return model
def main():
    """Command-line entry point: parse arguments, build the config, and run a job.

    Handles the meta-commands (create/eval/test/valid), the dump and package
    commands, the start and resume commands, folder initialization, random
    seeding, and finally creates and runs the requested Job. Any exception is
    logged to the config log before being re-raised.
    """
    # default config
    config = Config()

    # now parse the arguments
    parser = create_parser(config)
    args, unknown_args = parser.parse_known_args()

    # If there were unknown args, add them to the parser and reparse. The
    # correctness of these arguments will be checked later.
    if len(unknown_args) > 0:
        parser = create_parser(
            config, filter(lambda a: a.startswith("--"), unknown_args)
        )
        args = parser.parse_args()

    # process meta-commands (rewrite them into command + config overrides)
    process_meta_command(args, "create", {"command": "start", "run": False})
    process_meta_command(args, "eval", {"command": "resume", "job.type": "eval"})
    process_meta_command(
        args, "test", {"command": "resume", "job.type": "eval", "eval.split": "test"}
    )
    process_meta_command(
        args, "valid", {"command": "resume", "job.type": "eval", "eval.split": "valid"}
    )

    # dump command (exits the process when done)
    if args.command == "dump":
        dump(args)
        exit()

    # package command (exits the process when done)
    if args.command == "package":
        package_model(args)
        exit()

    # start command
    if args.command == "start":
        # use toy config file if no config given
        if args.config is None:
            args.config = kge_base_dir() + "/" + "examples/toy-complex-train.yaml"
            print(
                "WARNING: No configuration specified; using " + args.config,
                file=sys.stderr,
            )
        if not vars(args)["console.quiet"]:
            print("Loading configuration {}...".format(args.config))
        config.load(args.config)

    # resume command
    if args.command == "resume":
        # a folder argument is resolved to the config.yaml inside it
        if os.path.isdir(args.config) and os.path.isfile(args.config + "/config.yaml"):
            args.config += "/config.yaml"
        if not vars(args)["console.quiet"]:
            print("Resuming from configuration {}...".format(args.config))
        config.load(args.config)
        config.folder = os.path.dirname(args.config)
        if not config.folder:
            config.folder = "."
        if not os.path.exists(config.folder):
            raise ValueError(
                "{} is not a valid config file for resuming".format(args.config)
            )

    # overwrite configuration with command line arguments
    for key, value in vars(args).items():
        # skip the framework's own CLI options; they are not config keys
        if key in [
            "command",
            "config",
            "run",
            "folder",
            "checkpoint",
            "abort_when_cache_outdated",
        ]:
            continue
        if value is not None:
            if key == "search.device_pool":
                # comma-separated list of devices
                value = "".join(value).split(",")
            try:
                # coerce strings to bool when the existing config value is bool
                if isinstance(config.get(key), bool):
                    value = argparse_bool_type(value)
            except KeyError:
                pass
            config.set(key, value)
            if key == "model":
                # load the model's default configuration options
                config._import(value)

    # initialize output folder
    if args.command == "start":
        if args.folder is None:  # means: set default
            config_name = os.path.splitext(os.path.basename(args.config))[0]
            config.folder = os.path.join(
                kge_base_dir(),
                "local",
                "experiments",
                datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + config_name,
            )
        else:
            config.folder = args.folder

    # catch errors to log them
    try:
        if args.command == "start" and not config.init_folder():
            raise ValueError("output folder {} exists already".format(config.folder))
        config.log("Using folder: {}".format(config.folder))

        # determine checkpoint to resume (if any)
        # NOTE(review): checkpoint_file is only bound when args has a
        # "checkpoint" attribute; the resume branch below assumes it is set —
        # confirm the parser always defines it for resume.
        if hasattr(args, "checkpoint"):
            checkpoint_file = get_checkpoint_file(config, args.checkpoint)

        # disable processing of outdated cached dataset files globally
        Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated

        # log configuration
        config.log("Configuration:")
        config.log(yaml.dump(config.options), prefix=" ")
        config.log("git commit: {}".format(get_git_revision_short_hash()), prefix=" ")

        # set random seeds
        def get_seed(what):
            # per-PRNG seed; falls back to a value derived from the default seed
            seed = config.get(f"random_seed.{what}")
            if seed < 0 and config.get(f"random_seed.default") >= 0:
                import hashlib

                # we add an md5 hash to the default seed so that different
                # PRNGs get a different seed
                seed = (
                    config.get(f"random_seed.default")
                    + int(hashlib.md5(what.encode()).hexdigest(), 16)
                ) % 0xFFFF  # stay 32-bit
            return seed

        if get_seed("python") > -1:
            import random

            random.seed(get_seed("python"))
        if get_seed("torch") > -1:
            import torch

            torch.manual_seed(get_seed("torch"))
        if get_seed("numpy") > -1:
            import numpy.random

            numpy.random.seed(get_seed("numpy"))
        if get_seed("numba") > -1:
            import numpy as np, numba

            # numba's PRNG must be seeded from inside jitted code
            @numba.njit
            def seed_numba(seed):
                np.random.seed(seed)

            seed_numba(get_seed("numba"))

        # let's go
        if args.command == "start" and not args.run:
            config.log("Job created successfully.")
        else:
            # load data
            dataset = Dataset.create(config)

            # let's go
            if args.command == "resume":
                if checkpoint_file is not None:
                    checkpoint = load_checkpoint(
                        checkpoint_file, config.get("job.device")
                    )
                    job = Job.create_from(
                        checkpoint, new_config=config, dataset=dataset
                    )
                else:
                    job = Job.create(config, dataset)
                    job.config.log(
                        "No checkpoint found or specified, starting from scratch..."
                    )
            else:
                job = Job.create(config, dataset)
            job.run()
    except BaseException:
        # log the traceback before re-raising so failures appear in the job log
        tb = traceback.format_exc()
        config.log(tb, echo=False)
        raise