def __init__(
    self,
    config: Config,
    dataset: Dataset,
    configuration_key=None,
    init_for_load_only=False,
):
    self._init_configuration(config, configuration_key)

    # Initialize base model.
    # Using a dataset with twice the number of relations to initialize base model
    alt_dataset = dataset.shallow_copy()
    alt_dataset._num_relations = dataset.num_relations() * 2
    base_model = KgeModel.create(
        config=config,
        dataset=alt_dataset,
        configuration_key=self.configuration_key + ".base_model",
        init_for_load_only=init_for_load_only,
    )

    # Initialize this model
    super().__init__(
        config=config,
        dataset=dataset,
        scorer=base_model.get_scorer(),
        create_embedders=False,
        init_for_load_only=init_for_load_only,
    )
    self._base_model = base_model
    # TODO change entity_embedder assignment to sub and obj embedders when support
    # for that is added
    self._entity_embedder = self._base_model.get_s_embedder()
    self._relation_embedder = self._base_model.get_p_embedder()
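# Hedged explanatory sketch (comments only, not library code): the doubled
# relation count above implements the reciprocal-relations trick. With R
# original relations, relation id r answers (s, r, ?) queries, while id r + R
# answers (?, r, o) queries rewritten as (o, r_inv, ?). Assuming a trained
# ReciprocalRelationsModel `model` and 1-D long tensors s, p, o:
#
#   scores_sp = model.score_sp(s, p)  # object prediction, relations [0, R)
#   scores_po = model.score_po(p, o)  # subject prediction, relations [R, 2R)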
def test_store_index_pickle(self):
    dataset = Dataset.create(
        config=self.config, folder=self.dataset_folder, preload_data=True
    )
    for index_key in dataset.index_functions.keys():
        dataset.index(index_key)
        pickle_filename = os.path.join(
            self.dataset_folder,
            Dataset._to_valid_filename(f"index-{index_key}.pckl"),
        )
        # pickle_filename already includes the dataset folder; joining it with
        # the folder again would duplicate the path
        self.assertTrue(os.path.isfile(pickle_filename), msg=pickle_filename)
def create_from(
    checkpoint: Dict,
    dataset: Optional[Dataset] = None,
    use_tmp_log_folder=True,
    new_config: Config = None,
) -> "KgeModel":
    """Loads a model from a checkpoint file of a training job or a packaged model.

    If dataset is specified, associates this dataset with the model. Otherwise uses
    the dataset used to train the model.

    If `use_tmp_log_folder` is set, the logs and traces are written to a temporary
    file. Otherwise, the files `kge.log` and `trace.yaml` will be created (or
    appended to) in the checkpoint's folder.

    """
    config = Config.create_from(checkpoint)
    if new_config:
        config.load_config(new_config)
    if use_tmp_log_folder:
        import tempfile

        config.log_folder = tempfile.mkdtemp(prefix="kge-")
    else:
        config.log_folder = checkpoint["folder"]
        if not config.log_folder or not os.path.exists(config.log_folder):
            config.log_folder = "."
    dataset = Dataset.create_from(checkpoint, config, dataset, preload_data=False)
    model = KgeModel.create(config, dataset, init_for_load_only=True)
    model.load(checkpoint["model"])
    model.eval()
    return model
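# Hedged usage sketch (example, not library code): load a trained or packaged
# checkpoint and score a triple. The checkpoint filename and ids are illustrative.
import torch
from kge.model import KgeModel
from kge.util.io import load_checkpoint

checkpoint = load_checkpoint("fb15k-237-rescal.pt")  # assumed to exist locally
model = KgeModel.create_from(checkpoint)
s, p, o = torch.tensor([0]), torch.tensor([0]), torch.tensor([1])
scores = model.score_spo(s, p, o)  # one score per (s, p, o) triple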
def package_model(args):
    """Converts a checkpoint to a packaged model.

    A packaged model only contains the model, entity/relation ids, and the config.
    """
    checkpoint_file = args.checkpoint
    filename = args.file
    checkpoint = load_checkpoint(checkpoint_file, device="cpu")
    if checkpoint["type"] != "train":
        raise ValueError("Can only package trained checkpoints.")
    config = Config.create_from(checkpoint)
    dataset = Dataset.create_from(checkpoint, config, preload_data=False)
    packaged_model = {
        "type": "package",
        "model": checkpoint["model"],
        "epoch": checkpoint["epoch"],
        "job_id": checkpoint["job_id"],
        "valid_trace": checkpoint["valid_trace"],
    }
    packaged_model = config.save_to(packaged_model)
    packaged_model = dataset.save_to(
        packaged_model, ["entity_ids", "relation_ids"]
    )
    if filename is None:
        output_folder, filename = os.path.split(checkpoint_file)
        if "checkpoint" in filename:
            filename = filename.replace("checkpoint", "model")
        else:
            filename = filename.split(".pt")[0] + "_package.pt"
        filename = os.path.join(output_folder, filename)
    print(f"Saving to {filename}...")
    torch.save(packaged_model, filename)
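# Usage note (hedged): package_model is reached through the command-line
# "package" command handled in main() below, e.g.
#
#   kge package local/experiments/<folder>/checkpoint_best.pt
#
# Following the naming logic above, this writes model_best.pt next to the
# checkpoint unless an explicit output name is supplied via args.file.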
def create(
    config: Config, dataset: Optional[Dataset] = None, parent_job=None, model=None
):
    "Create a new job."
    from kge.job import TrainingJob, EvaluationJob, SearchJob

    if dataset is None:
        dataset = Dataset.create(config)

    job_type = config.get("job.type")
    if job_type == "train":
        return TrainingJob.create(config, dataset, parent_job=parent_job, model=model)
    elif job_type == "search":
        return SearchJob.create(config, dataset, parent_job=parent_job)
    elif job_type == "eval":
        return EvaluationJob.create(config, dataset, parent_job=parent_job, model=model)
    else:
        raise ValueError("unknown job type")
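# Hedged usage sketch: create and run a job from an existing config file; the
# path is illustrative. Job.create dispatches on config.get("job.type").
from kge import Config, Dataset
from kge.job import Job

config = Config()
config.load("local/experiments/my-run/config.yaml")
dataset = Dataset.create(config)
job = Job.create(config, dataset)
job.run()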
def test_data_pickle_correctness(self):
    # this will create new pickle files for train, valid, test
    dataset = Dataset.create(
        config=self.config, folder=self.dataset_folder, preload_data=True
    )

    # create new dataset which loads the triples from stored pckl files
    dataset_load_by_pickle = Dataset.create(
        config=self.config, folder=self.dataset_folder, preload_data=True
    )

    for split in dataset._triples.keys():
        self.assertTrue(
            torch.all(
                torch.eq(dataset_load_by_pickle.split(split), dataset.split(split))
            )
        )
    self.assertEqual(dataset._meta, dataset_load_by_pickle._meta)
def load_from_checkpoint(
    filename: str, dataset=None, use_tmp_log_folder=True, device="cpu"
) -> "KgeModel":
    """Loads a model from a checkpoint file of a training job.

    If dataset is specified, associates this dataset with the model. Otherwise uses
    the dataset used to train the model.

    If `use_tmp_log_folder` is set, the logs and traces are written to a temporary
    file. Otherwise, the files `kge.log` and `trace.yaml` will be created (or
    appended to) in the checkpoint's folder.

    """
    checkpoint = torch.load(filename, map_location=device)
    original_config = checkpoint["config"]
    config = Config()  # round trip to handle deprecated configs
    config.load_options(original_config.options)
    config.set("job.device", device)
    if use_tmp_log_folder:
        import tempfile

        config.log_folder = tempfile.mkdtemp(prefix="kge-")
    else:
        config.log_folder = os.path.dirname(filename)
        if not config.log_folder:
            config.log_folder = "."
    if dataset is None:
        dataset = Dataset.load(config, preload_data=False)
    model = KgeModel.create(config, dataset)
    model.load(checkpoint["model"])
    model.eval()
    return model
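# Hedged usage sketch (comments only): this older entry point loads a model
# directly from a checkpoint file and places it on the requested device:
#
#   model = KgeModel.load_from_checkpoint("checkpoint_best.pt", device="cpu")
#   scores = model.score_spo(s, p, o)  # s, p, o: 1-D long tensors of equal length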
def __init__(self, config: Config, configuration_key: str, dataset: Dataset):
    super().__init__(config, configuration_key)

    # load config
    self.num_samples = torch.zeros(3, dtype=torch.int)
    self.filter_positives = torch.zeros(3, dtype=torch.bool)
    self.vocabulary_size = torch.zeros(3, dtype=torch.int)
    self.shared = self.get_option("shared")
    self.with_replacement = self.get_option("with_replacement")
    if not self.with_replacement and not self.shared:
        raise ValueError(
            "Without replacement sampling is only supported when "
            "shared negative sampling is enabled."
        )
    self.filtering_split = config.get("negative_sampling.filtering.split")
    if self.filtering_split == "":
        self.filtering_split = config.get("train.split")
    for slot in SLOTS:
        slot_str = SLOT_STR[slot]
        self.num_samples[slot] = self.get_option(f"num_samples.{slot_str}")
        self.filter_positives[slot] = self.get_option(f"filtering.{slot_str}")
        self.vocabulary_size[slot] = (
            dataset.num_relations() if slot == P else dataset.num_entities()
        )
        # create indices for filtering here already if needed and not existing
        # otherwise every worker would create every index again and again
        if self.filter_positives[slot]:
            pair = ["po", "so", "sp"][slot]
            dataset.index(f"{self.filtering_split}_{pair}_to_{slot_str}")
    if any(self.filter_positives):
        if self.shared:
            raise ValueError(
                "Filtering is not supported when shared negative sampling is enabled."
            )
        self.check_option(
            "filtering.implementation", ["standard", "fast", "fast_if_available"]
        )
        self.filter_implementation = self.get_option("filtering.implementation")
    self.dataset = dataset

    # auto config
    for slot, copy_from in [(S, O), (P, None), (O, S)]:
        if self.num_samples[slot] < 0:
            if copy_from is not None and self.num_samples[copy_from] > 0:
                self.num_samples[slot] = self.num_samples[copy_from]
            else:
                self.num_samples[slot] = 0
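# Minimal standalone sketch (illustrative only, plain lists instead of config
# options) of the "auto config" rule above: a negative num_samples value copies
# the count from the opposite entity slot (S from O, O from S); P has no
# counterpart and falls back to 0.
S, P, O = 0, 1, 2
num_samples = [-1, -1, 10]  # e.g. only num_samples.o was configured
for slot, copy_from in [(S, O), (P, None), (O, S)]:
    if num_samples[slot] < 0:
        if copy_from is not None and num_samples[copy_from] > 0:
            num_samples[slot] = num_samples[copy_from]
        else:
            num_samples[slot] = 0
assert num_samples == [10, 0, 10]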
def _create_dataset_and_indexes():
    # note: defined inside a test method; `self` is taken from the enclosing scope
    data = Dataset.create(
        config=self.config, folder=self.dataset_folder, preload_data=True
    )
    indexes = []
    for index_key in data.index_functions.keys():
        indexes.append(data.index(index_key))
    return data, indexes
def setUp(self):
    self.config = create_config(self.dataset_name, model=self.model_name)
    self.config.set_all({"lookup_embedder.dim": 32})
    self.config.set_all(self.options)
    self.dataset_folder = get_dataset_folder(self.dataset_name)
    self.dataset = Dataset.create(self.config, folder=self.dataset_folder)
    self.model = KgeModel.create(self.config, self.dataset)
def __init__(
    self,
    config: Config,
    dataset: Dataset,
    scorer: Union[RelationalScorer, type],
    initialize_embedders=True,
    configuration_key=None,
):
    super().__init__(config, dataset, configuration_key)

    # TODO support different embedders for subjects and objects

    #: Embedder used for entities (both subjects and objects)
    self._entity_embedder: KgeEmbedder

    #: Embedder used for relations
    self._relation_embedder: KgeEmbedder

    if initialize_embedders:
        self._entity_embedder = KgeEmbedder.create(
            config,
            dataset,
            self.configuration_key + ".entity_embedder",
            dataset.num_entities(),
        )

        num_relations = dataset.num_relations()
        self._relation_embedder = KgeEmbedder.create(
            config,
            dataset,
            self.configuration_key + ".relation_embedder",
            num_relations,
        )

    #: Scorer
    self._scorer: RelationalScorer
    if type(scorer) == type:
        # scorer is type of the scorer to use; call its constructor
        self._scorer = scorer(
            config=config, dataset=dataset, configuration_key=self.configuration_key
        )
    else:
        self._scorer = scorer
def setUp(self):
    self.config = create_config(self.dataset_name)
    self.config.set_all({"lookup_embedder.dim": 32})
    self.config.set("job.type", "train")
    self.config.set("train.type", self.train_type)
    self.config.set_all(self.options)
    self.dataset_folder = get_dataset_folder(self.dataset_name)
    self.dataset = Dataset.create(self.config, folder=self.dataset_folder)
    self.model = KgeModel.create(self.config, self.dataset)
def test_store_data_pickle(self):
    # this will create new pickle files for train, valid, test
    dataset = Dataset.create(
        config=self.config, folder=self.dataset_folder, preload_data=True
    )
    pickle_filenames = [
        "train.del-t.pckl",
        "valid.del-t.pckl",
        "test.del-t.pckl",
        "entity_ids.del-True-t-False.pckl",
        "relation_ids.del-True-t-False.pckl",
    ]
    for filename in pickle_filenames:
        self.assertTrue(
            os.path.isfile(os.path.join(self.dataset_folder, filename)),
            msg=filename,
        )
def create_default(
    model: Optional[str] = None,
    dataset: Optional[Union[Dataset, str]] = None,
    options: Dict[str, Any] = {},
    folder: Optional[str] = None,
) -> "KgeModel":
    """Utility method to create a model, including configuration and dataset.

    `model` is the name of the model (takes precedence over ``options["model"]``),
    `dataset` a dataset name or `Dataset` instance (takes precedence over
    ``options["dataset.name"]``), and `options` arbitrary other configuration
    options.

    If `folder` is ``None``, creates a temporary folder. Otherwise uses the
    specified folder.

    """
    # load default model config
    if model is None:
        model = options["model"]
    default_config_file = filename_in_module(kge.model, "{}.yaml".format(model))
    config = Config()
    config.load(default_config_file, create=True)

    # apply specified options
    config.set("model", model)
    if isinstance(dataset, Dataset):
        config.set("dataset.name", dataset.config.get("dataset.name"))
    elif isinstance(dataset, str):
        config.set("dataset.name", dataset)
    config.set_all(new_options=options)

    # create output folder
    if folder is None:
        config.folder = tempfile.mkdtemp(
            "{}-{}-".format(config.get("dataset.name"), config.get("model"))
        )
    else:
        config.folder = folder

    # create dataset and model
    if not isinstance(dataset, Dataset):
        dataset = Dataset.create(config)
    model = KgeModel.create(config, dataset)
    return model
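# Hedged usage sketch: create_default wires up config, dataset, and model in
# one call. "complex" and "toy" are illustrative names; they require the
# corresponding model yaml shipped with kge and a downloaded dataset.
from kge.model import KgeModel

model = KgeModel.create_default(
    model="complex", dataset="toy", options={"lookup_embedder.dim": 64}
)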
@classmethod
def create_from(
    cls,
    checkpoint: Dict,
    new_config: Config = None,
    dataset: Dataset = None,
    parent_job=None,
    parameter_client=None,
) -> Job:
    """Creates a Job based on a checkpoint.

    Args:
        checkpoint: loaded checkpoint
        new_config: optional config object - overwrites options of config
            stored in checkpoint
        dataset: dataset object
        parent_job: parent job (e.g. search job)

    Returns: Job based on checkpoint

    """
    from kge.model import KgeModel

    model: KgeModel = None
    # search jobs don't have a model
    if "model" in checkpoint and checkpoint["model"] is not None:
        model = KgeModel.create_from(
            checkpoint,
            new_config=new_config,
            dataset=dataset,
            parameter_client=parameter_client,
        )
        config = model.config
        dataset = model.dataset
    else:
        config = Config.create_from(checkpoint)
        if new_config:
            config.load_config(new_config)
        dataset = Dataset.create_from(checkpoint, config, dataset)
    job = Job.create(
        config,
        dataset,
        parent_job,
        model,
        parameter_client=parameter_client,
        init_for_load_only=True,
    )
    job._load(checkpoint)
    job.config.log("Loaded checkpoint from {}...".format(checkpoint["file"]))
    return job
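# Hedged usage sketch (comments only; mirrors the resume path in main() below):
# restore a job from a checkpoint file. The path is illustrative.
#
#   from kge.util.io import load_checkpoint
#   checkpoint = load_checkpoint("local/experiments/my-run/checkpoint_best.pt")
#   job = Job.create_from(checkpoint)
#   job.run()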
def main():
    # default config
    config = Config()

    # now parse the arguments
    parser = create_parser(config)
    args, unknown_args = parser.parse_known_args()

    # If there were unknown args, add them to the parser and reparse. The correctness
    # of these arguments will be checked later.
    if len(unknown_args) > 0:
        parser = create_parser(
            config, filter(lambda a: a.startswith("--"), unknown_args)
        )
        args = parser.parse_args()

    # process meta-commands
    process_meta_command(args, "create", {"command": "start", "run": False})
    process_meta_command(args, "eval", {"command": "resume", "job.type": "eval"})
    process_meta_command(
        args, "test", {"command": "resume", "job.type": "eval", "eval.split": "test"}
    )
    process_meta_command(
        args, "valid", {"command": "resume", "job.type": "eval", "eval.split": "valid"}
    )

    # dump command
    if args.command == "dump":
        dump(args)
        exit()

    # package command
    if args.command == "package":
        package_model(args)
        exit()

    # start command
    if args.command == "start":
        # use toy config file if no config given
        if args.config is None:
            args.config = kge_base_dir() + "/" + "examples/toy-complex-train.yaml"
            print(
                "WARNING: No configuration specified; using " + args.config,
                file=sys.stderr,
            )
        if not vars(args)["console.quiet"]:
            print("Loading configuration {}...".format(args.config))
        config.load(args.config)

    # resume command
    if args.command == "resume":
        if os.path.isdir(args.config) and os.path.isfile(args.config + "/config.yaml"):
            args.config += "/config.yaml"
        if not vars(args)["console.quiet"]:
            print("Resuming from configuration {}...".format(args.config))
        config.load(args.config)
        config.folder = os.path.dirname(args.config)
        if not config.folder:
            config.folder = "."
        if not os.path.exists(config.folder):
            raise ValueError(
                "{} is not a valid config file for resuming".format(args.config)
            )

    # overwrite configuration with command line arguments
    for key, value in vars(args).items():
        if key in [
            "command",
            "config",
            "run",
            "folder",
            "checkpoint",
            "abort_when_cache_outdated",
        ]:
            continue
        if value is not None:
            if key == "search.device_pool":
                value = "".join(value).split(",")
            try:
                if isinstance(config.get(key), bool):
                    value = argparse_bool_type(value)
            except KeyError:
                pass
            config.set(key, value)
            if key == "model":
                config._import(value)

    # initialize output folder
    if args.command == "start":
        if args.folder is None:  # means: set default
            config_name = os.path.splitext(os.path.basename(args.config))[0]
            config.folder = os.path.join(
                kge_base_dir(),
                "local",
                "experiments",
                datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + config_name,
            )
        else:
            config.folder = args.folder

    # catch errors to log them
    try:
        if args.command == "start" and not config.init_folder():
            raise ValueError("output folder {} exists already".format(config.folder))
        config.log("Using folder: {}".format(config.folder))

        # determine checkpoint to resume (if any)
        if hasattr(args, "checkpoint"):
            checkpoint_file = get_checkpoint_file(config, args.checkpoint)

        # disable processing of outdated cached dataset files globally
        Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated

        # log configuration
        config.log("Configuration:")
        config.log(yaml.dump(config.options), prefix="  ")
        config.log("git commit: {}".format(get_git_revision_short_hash()), prefix="  ")

        # set random seeds
        def get_seed(what):
            seed = config.get(f"random_seed.{what}")
            if seed < 0 and config.get(f"random_seed.default") >= 0:
                import hashlib

                # we add an md5 hash to the default seed so that different PRNGs get a
                # different seed
                seed = (
                    config.get(f"random_seed.default")
                    + int(hashlib.md5(what.encode()).hexdigest(), 16)
                ) % 0xFFFF  # keep the seed small (within 16 bits)
            return seed

        if get_seed("python") > -1:
            import random

            random.seed(get_seed("python"))
        if get_seed("torch") > -1:
            import torch

            torch.manual_seed(get_seed("torch"))
        if get_seed("numpy") > -1:
            import numpy.random

            numpy.random.seed(get_seed("numpy"))
        if get_seed("numba") > -1:
            import numpy as np, numba

            @numba.njit
            def seed_numba(seed):
                np.random.seed(seed)

            seed_numba(get_seed("numba"))

        # let's go
        if args.command == "start" and not args.run:
            config.log("Job created successfully.")
        else:
            # load data
            dataset = Dataset.create(config)

            # let's go
            if args.command == "resume":
                if checkpoint_file is not None:
                    checkpoint = load_checkpoint(
                        checkpoint_file, config.get("job.device")
                    )
                    job = Job.create_from(
                        checkpoint, new_config=config, dataset=dataset
                    )
                else:
                    job = Job.create(config, dataset)
                    job.config.log(
                        "No checkpoint found or specified, starting from scratch..."
                    )
            else:
                job = Job.create(config, dataset)
            job.run()
    except BaseException:
        tb = traceback.format_exc()
        config.log(tb, echo=False)
        raise
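# Usage note (hedged): typical invocations that reach this main(), using the
# meta-commands processed above; paths are illustrative.
#
#   kge start examples/toy-complex-train.yaml   # create and run a training job
#   kge resume local/experiments/<folder>       # resume from its checkpoint
#   kge eval local/experiments/<folder>         # resume as an eval job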
def __init__(
    self,
    config: Config,
    dataset: Dataset,
    scorer: Union[RelationalScorer, type],
    create_embedders=True,
    configuration_key=None,
    init_for_load_only=False,
):
    super().__init__(config, dataset, configuration_key)

    # TODO support different embedders for subjects and objects

    #: Embedder used for entities (both subjects and objects)
    self._entity_embedder: KgeEmbedder

    #: Embedder used for relations
    self._relation_embedder: KgeEmbedder

    if create_embedders:
        self._entity_embedder = KgeEmbedder.create(
            config,
            dataset,
            self.configuration_key + ".entity_embedder",
            dataset.num_entities(),
            init_for_load_only=init_for_load_only,
        )

        num_relations = dataset.num_relations()
        self._relation_embedder = KgeEmbedder.create(
            config,
            dataset,
            self.configuration_key + ".relation_embedder",
            num_relations,
            init_for_load_only=init_for_load_only,
        )

        if not init_for_load_only:
            # load pretrained embeddings
            pretrained_entities_filename = ""
            pretrained_relations_filename = ""
            if self.has_option("entity_embedder.pretrain.model_filename"):
                pretrained_entities_filename = self.get_option(
                    "entity_embedder.pretrain.model_filename"
                )
            if self.has_option("relation_embedder.pretrain.model_filename"):
                pretrained_relations_filename = self.get_option(
                    "relation_embedder.pretrain.model_filename"
                )

            def load_pretrained_model(
                pretrained_filename: str,
            ) -> Optional[KgeModel]:
                if pretrained_filename != "":
                    self.config.log(
                        f"Initializing with embeddings stored in "
                        f"{pretrained_filename}"
                    )
                    checkpoint = load_checkpoint(pretrained_filename)
                    return KgeModel.create_from(checkpoint)
                return None

            pretrained_entities_model = load_pretrained_model(
                pretrained_entities_filename
            )
            if pretrained_entities_filename == pretrained_relations_filename:
                pretrained_relations_model = pretrained_entities_model
            else:
                pretrained_relations_model = load_pretrained_model(
                    pretrained_relations_filename
                )
            if pretrained_entities_model is not None:
                if (
                    pretrained_entities_model.get_s_embedder()
                    != pretrained_entities_model.get_o_embedder()
                ):
                    raise ValueError(
                        "Can only initialize with pre-trained models having "
                        "identical subject and object embeddings."
                    )
                self._entity_embedder.init_pretrained(
                    pretrained_entities_model.get_s_embedder()
                )
            if pretrained_relations_model is not None:
                self._relation_embedder.init_pretrained(
                    pretrained_relations_model.get_p_embedder()
                )

    #: Scorer
    self._scorer: RelationalScorer
    if type(scorer) == type:
        # scorer is type of the scorer to use; call its constructor
        self._scorer = scorer(
            config=config, dataset=dataset, configuration_key=self.configuration_key
        )
    else:
        self._scorer = scorer
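# Hedged configuration sketch (comments only): the pretrain options read above
# live under the model's embedder keys. The "complex" prefix and filename are
# illustrative; when both options point to the same file, the checkpoint is
# loaded only once (see the equality check above).
#
#   complex.entity_embedder.pretrain.model_filename: pretrained.pt
#   complex.relation_embedder.pretrain.model_filename: pretrained.pt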
def main():
    # default config
    config = Config()

    # now parse the arguments
    parser = create_parser(config)
    args, unknown_args = parser.parse_known_args()

    # If there were unknown args, add them to the parser and reparse. The correctness
    # of these arguments will be checked later.
    if len(unknown_args) > 0:
        parser = create_parser(
            config, filter(lambda a: a.startswith("--"), unknown_args)
        )
        args = parser.parse_args()

    # process meta-commands
    process_meta_command(args, "create", {"command": "start", "run": False})
    process_meta_command(args, "eval", {"command": "resume", "job.type": "eval"})
    process_meta_command(
        args, "test", {"command": "resume", "job.type": "eval", "eval.split": "test"}
    )
    process_meta_command(
        args, "valid", {"command": "resume", "job.type": "eval", "eval.split": "valid"}
    )

    # dump command
    if args.command == "dump":
        dump(args)
        exit()

    # start command
    if args.command == "start":
        # use toy config file if no config given
        if args.config is None:
            args.config = kge_base_dir() + "/" + "examples/toy-complex-train.yaml"
            print("WARNING: No configuration specified; using " + args.config)
        print("Loading configuration {}...".format(args.config))
        config.load(args.config)

    # resume command
    if args.command == "resume":
        if os.path.isdir(args.config) and os.path.isfile(args.config + "/config.yaml"):
            args.config += "/config.yaml"
        print("Resuming from configuration {}...".format(args.config))
        config.load(args.config)
        config.folder = os.path.dirname(args.config)
        if not config.folder:
            config.folder = "."
        if not os.path.exists(config.folder):
            raise ValueError(
                "{} is not a valid config file for resuming".format(args.config)
            )

    # overwrite configuration with command line arguments
    for key, value in vars(args).items():
        if key in [
            "command",
            "config",
            "run",
            "folder",
            "checkpoint",
            "abort_when_cache_outdated",
        ]:
            continue
        if value is not None:
            if key == "search.device_pool":
                value = "".join(value).split(",")
            try:
                if isinstance(config.get(key), bool):
                    value = argparse_bool_type(value)
            except KeyError:
                pass
            config.set(key, value)
            if key == "model":
                config._import(value)

    # initialize output folder
    if args.command == "start":
        if args.folder is None:  # means: set default
            config_name = os.path.splitext(os.path.basename(args.config))[0]
            config.folder = os.path.join(
                kge_base_dir(),
                "local",
                "experiments",
                datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + config_name,
            )
        else:
            config.folder = args.folder

    # catch errors to log them
    try:
        if args.command == "start" and not config.init_folder():
            raise ValueError("output folder {} exists already".format(config.folder))
        config.log("Using folder: {}".format(config.folder))

        # determine checkpoint to resume (if any)
        if hasattr(args, "checkpoint"):
            if args.checkpoint == "default":
                if config.get("job.type") in ["eval", "valid"]:
                    checkpoint_file = config.checkpoint_file("best")
                else:
                    checkpoint_file = None  # means last
            elif is_number(args.checkpoint, int) or args.checkpoint == "best":
                checkpoint_file = config.checkpoint_file(args.checkpoint)
            else:
                # otherwise, treat it as a filename
                checkpoint_file = args.checkpoint

        # disable processing of outdated cached dataset files globally
        Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated

        # log configuration
        config.log("Configuration:")
        config.log(yaml.dump(config.options), prefix="  ")
        config.log("git commit: {}".format(get_git_revision_short_hash()), prefix="  ")

        # set random seeds
        if config.get("random_seed.python") > -1:
            import random

            random.seed(config.get("random_seed.python"))
        if config.get("random_seed.torch") > -1:
            import torch

            torch.manual_seed(config.get("random_seed.torch"))
        if config.get("random_seed.numpy") > -1:
            import numpy.random

            numpy.random.seed(config.get("random_seed.numpy"))

        # let's go
        if args.command == "start" and not args.run:
            config.log("Job created successfully.")
        else:
            # load data
            dataset = Dataset.load(config)

            # let's go
            job = Job.create(config, dataset)
            if args.command == "resume":
                job.resume(checkpoint_file)
            job.run()
    except BaseException as e:
        tb = traceback.format_exc()
        config.log(tb, echo=False)
        raise e from None
def __init__(
    self,
    config: Config,
    dataset: Dataset,
    scorer: Union[RelationalScorer, type],
    create_embedders=True,
    configuration_key=None,
    init_for_load_only=False,
    parameter_client=None,
    max_partition_entities=0,
):
    super().__init__(config, dataset, configuration_key)

    # TODO support different embedders for subjects and objects

    #: Embedder used for entities (both subjects and objects)
    self._entity_embedder: KgeEmbedder

    #: Embedder used for relations
    self._relation_embedder: KgeEmbedder

    if create_embedders:
        self._create_embedders(init_for_load_only)
    elif False:  # if self.get_option("create_complete"):
        # embedding_layer_size = dataset.num_entities()
        if (
            config.get("job.distributed.entity_sync_level") == "partition"
            and max_partition_entities != 0
        ):
            embedding_layer_size = max_partition_entities
        else:
            embedding_layer_size = self._calc_embedding_layer_size(config, dataset)
        config.log(f"creating entity_embedder with {embedding_layer_size} keys")
        self._entity_embedder = KgeEmbedder.create(
            config=config,
            dataset=dataset,
            configuration_key=self.configuration_key + ".entity_embedder",
            vocab_size=embedding_layer_size,  # instead of dataset.num_entities()
            init_for_load_only=init_for_load_only,
            parameter_client=parameter_client,
            lapse_offset=0,
            complete_vocab_size=dataset.num_entities(),
        )

        #: Embedder used for relations
        num_relations = dataset.num_relations()
        self._relation_embedder = KgeEmbedder.create(
            config,
            dataset,
            self.configuration_key + ".relation_embedder",
            num_relations,
            init_for_load_only=init_for_load_only,
            parameter_client=parameter_client,
            lapse_offset=dataset.num_entities(),
            complete_vocab_size=dataset.num_relations(),
        )

    if not init_for_load_only and parameter_client.rank == get_min_rank(config):
        # load pretrained embeddings
        pretrained_entities_filename = ""
        pretrained_relations_filename = ""
        if self.has_option("entity_embedder.pretrain.model_filename"):
            pretrained_entities_filename = self.get_option(
                "entity_embedder.pretrain.model_filename"
            )
        if self.has_option("relation_embedder.pretrain.model_filename"):
            pretrained_relations_filename = self.get_option(
                "relation_embedder.pretrain.model_filename"
            )

        def load_pretrained_model(
            pretrained_filename: str,
        ) -> Optional[KgeModel]:
            if pretrained_filename != "":
                self.config.log(
                    f"Initializing with embeddings stored in "
                    f"{pretrained_filename}"
                )
                checkpoint = load_checkpoint(pretrained_filename)
                return KgeModel.create_from(
                    checkpoint, parameter_client=parameter_client
                )
            return None

        pretrained_entities_model = load_pretrained_model(
            pretrained_entities_filename
        )
        if pretrained_entities_filename == pretrained_relations_filename:
            pretrained_relations_model = pretrained_entities_model
        else:
            pretrained_relations_model = load_pretrained_model(
                pretrained_relations_filename
            )
        if pretrained_entities_model is not None:
            if (
                pretrained_entities_model.get_s_embedder()
                != pretrained_entities_model.get_o_embedder()
            ):
                raise ValueError(
                    "Can only initialize with pre-trained models having "
                    "identical subject and object embeddings."
                )
            self._entity_embedder.init_pretrained(
                pretrained_entities_model.get_s_embedder()
            )
        if pretrained_relations_model is not None:
            self._relation_embedder.init_pretrained(
                pretrained_relations_model.get_p_embedder()
            )

    #: Scorer
    self._scorer: RelationalScorer
    if type(scorer) == type:
        # scorer is type of the scorer to use; call its constructor
        self._scorer = scorer(
            config=config, dataset=dataset, configuration_key=self.configuration_key
        )
    else:
        self._scorer = scorer