def __init__(
        self,
        config: Config,
        dataset: Dataset,
        configuration_key=None,
        init_for_load_only=False,
    ):
        self._init_configuration(config, configuration_key)

        # Initialize the base model with a shallow copy of the dataset that has twice
        # the number of relations (one reciprocal relation per original relation)
        alt_dataset = dataset.shallow_copy()
        alt_dataset._num_relations = dataset.num_relations() * 2
        base_model = KgeModel.create(
            config=config,
            dataset=alt_dataset,
            configuration_key=self.configuration_key + ".base_model",
            init_for_load_only=init_for_load_only,
        )

        # Initialize this model
        super().__init__(
            config=config,
            dataset=dataset,
            scorer=base_model.get_scorer(),
            create_embedders=False,
            init_for_load_only=init_for_load_only,
        )
        self._base_model = base_model
        # TODO change entity_embedder assignment to sub and obj embedders when support
        # for that is added
        self._entity_embedder = self._base_model.get_s_embedder()
        self._relation_embedder = self._base_model.get_p_embedder()
Example #2
 def test_store_index_pickle(self):
     dataset = Dataset.create(config=self.config,
                              folder=self.dataset_folder,
                              preload_data=True)
     for index_key in dataset.index_functions.keys():
         dataset.index(index_key)
         pickle_filename = os.path.join(
             self.dataset_folder,
             Dataset._to_valid_filename(f"index-{index_key}.pckl"),
         )
         self.assertTrue(
             # pickle_filename already includes the dataset folder
             os.path.isfile(pickle_filename),
             msg=pickle_filename,
         )
Example #3
    def create_from(
        checkpoint: Dict,
        dataset: Optional[Dataset] = None,
        use_tmp_log_folder=True,
        new_config: Config = None,
    ) -> "KgeModel":
        """Loads a model from a checkpoint file of a training job or a packaged model.

        If dataset is specified, associates this dataset with the model. Otherwise uses
        the dataset used to train the model.

        If `use_tmp_log_folder` is set, the logs and traces are written to a temporary
        folder. Otherwise, the files `kge.log` and `trace.yaml` will be created (or
        appended to) in the checkpoint's folder.

        """
        config = Config.create_from(checkpoint)
        if new_config:
            config.load_config(new_config)

        if use_tmp_log_folder:
            import tempfile

            config.log_folder = tempfile.mkdtemp(prefix="kge-")
        else:
            config.log_folder = checkpoint["folder"]
            if not config.log_folder or not os.path.exists(config.log_folder):
                config.log_folder = "."
        dataset = Dataset.create_from(checkpoint, config, dataset, preload_data=False)
        model = KgeModel.create(config, dataset, init_for_load_only=True)
        model.load(checkpoint["model"])
        model.eval()
        return model
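
A minimal usage sketch for the loader above; the checkpoint path is a placeholder, the triple indices are arbitrary, and `score_spo` is assumed to be the usual scoring entry point:

import torch
from kge.util.io import load_checkpoint
from kge.model import KgeModel

# Placeholder path; point this at a real training or packaged checkpoint.
checkpoint = load_checkpoint("local/experiments/my-run/checkpoint_best.pt")
model = KgeModel.create_from(checkpoint)  # logs go to a temporary folder by default

# Score an (s, p, o) triple given as index tensors.
s, p, o = torch.tensor([0]), torch.tensor([1]), torch.tensor([2])
print(model.score_spo(s, p, o))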
Example #4
def package_model(args):
    """
    Converts a checkpoint to a packaged model.
    A packaged model only contains the model, entity/relation ids and the config.
    """
    checkpoint_file = args.checkpoint
    filename = args.file
    checkpoint = load_checkpoint(checkpoint_file, device="cpu")
    if checkpoint["type"] != "train":
        raise ValueError("Can only package trained checkpoints.")
    config = Config.create_from(checkpoint)
    dataset = Dataset.create_from(checkpoint, config, preload_data=False)
    packaged_model = {
        "type": "package",
        "model": checkpoint["model"],
        "epoch": checkpoint["epoch"],
        "job_id": checkpoint["job_id"],
        "valid_trace": checkpoint["valid_trace"],
    }
    packaged_model = config.save_to(packaged_model)
    packaged_model = dataset.save_to(
        packaged_model,
        ["entity_ids", "relation_ids"],
    )
    if filename is None:
        output_folder, filename = os.path.split(checkpoint_file)
        if "checkpoint" in filename:
            filename = filename.replace("checkpoint", "model")
        else:
            filename = filename.split(".pt")[0] + "_package.pt"
        filename = os.path.join(output_folder, filename)
    print(f"Saving to {filename}...")
    torch.save(packaged_model, filename)
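
A hedged sketch of driving the packaging helper above programmatically; in the command-line tool it is reached via the `package` command (see the `main()` functions further below), and the checkpoint path here is a placeholder:

from argparse import Namespace

# file=None lets package_model derive the output name from the checkpoint name
# (e.g. checkpoint_best.pt -> model_best.pt).
args = Namespace(
    checkpoint="local/experiments/my-run/checkpoint_best.pt",  # placeholder
    file=None,
)
package_model(args)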
Example #5
    def create(config: Config,
               dataset: Optional[Dataset] = None,
               parent_job=None,
               model=None):
        "Create a new job."
        from kge.job import TrainingJob, EvaluationJob, SearchJob

        if dataset is None:
            dataset = Dataset.create(config)

        job_type = config.get("job.type")
        if job_type == "train":
            return TrainingJob.create(config,
                                      dataset,
                                      parent_job=parent_job,
                                      model=model)
        elif job_type == "search":
            return SearchJob.create(config, dataset, parent_job=parent_job)
        elif job_type == "eval":
            return EvaluationJob.create(config,
                                        dataset,
                                        parent_job=parent_job,
                                        model=model)
        else:
            raise ValueError("unknown job type")
Example #6
    def test_data_pickle_correctness(self):
        # this will create new pickle files for train, valid, test
        dataset = Dataset.create(config=self.config,
                                 folder=self.dataset_folder,
                                 preload_data=True)

        # create new dataset which loads the triples from stored pckl files
        dataset_load_by_pickle = Dataset.create(config=self.config,
                                                folder=self.dataset_folder,
                                                preload_data=True)
        for split in dataset._triples.keys():
            self.assertTrue(
                torch.all(
                    torch.eq(dataset_load_by_pickle.split(split),
                             dataset.split(split))))
        self.assertEqual(dataset._meta, dataset_load_by_pickle._meta)
Example #7
    def load_from_checkpoint(filename: str,
                             dataset=None,
                             use_tmp_log_folder=True,
                             device="cpu") -> "KgeModel":
        """Loads a model from a checkpoint file of a training job.

        If dataset is specified, associates this dataset with the model. Otherwise uses
        the dataset used to train the model.

        If `use_tmp_log_folder` is set, the logs and traces are written to a temporary
        folder. Otherwise, the files `kge.log` and `trace.yaml` will be created (or
        appended to) in the checkpoint's folder.

        """

        checkpoint = torch.load(filename, map_location=device)

        original_config = checkpoint["config"]
        config = Config()  # round trip to handle deprecated configs
        config.load_options(original_config.options)
        config.set("job.device", device)
        if use_tmp_log_folder:
            import tempfile

            config.log_folder = tempfile.mkdtemp(prefix="kge-")
        else:
            config.log_folder = os.path.dirname(filename)
            if not config.log_folder:
                config.log_folder = "."
        if dataset is None:
            dataset = Dataset.load(config, preload_data=False)
        model = KgeModel.create(config, dataset)
        model.load(checkpoint["model"])
        model.eval()
        return model
Example #8
    def __init__(self, config: Config, configuration_key: str,
                 dataset: Dataset):
        super().__init__(config, configuration_key)

        # load config
        self.num_samples = torch.zeros(3, dtype=torch.int)
        self.filter_positives = torch.zeros(3, dtype=torch.bool)
        self.vocabulary_size = torch.zeros(3, dtype=torch.int)
        self.shared = self.get_option("shared")
        self.with_replacement = self.get_option("with_replacement")
        if not self.with_replacement and not self.shared:
            raise ValueError(
                "Without replacement sampling is only supported when "
                "shared negative sampling is enabled.")
        self.filtering_split = config.get("negative_sampling.filtering.split")
        if self.filtering_split == "":
            self.filtering_split = config.get("train.split")
        for slot in SLOTS:
            slot_str = SLOT_STR[slot]
            self.num_samples[slot] = self.get_option(f"num_samples.{slot_str}")
            self.filter_positives[slot] = self.get_option(
                f"filtering.{slot_str}")
            self.vocabulary_size[slot] = (dataset.num_relations() if slot == P
                                          else dataset.num_entities())
            # create the filtering indexes here already (if needed and not yet existing);
            # otherwise every worker would create each index again and again
            if self.filter_positives[slot]:
                pair = ["po", "so", "sp"][slot]
                dataset.index(f"{self.filtering_split}_{pair}_to_{slot_str}")
        if any(self.filter_positives):
            if self.shared:
                raise ValueError(
                    "Filtering is not supported when shared negative sampling is enabled."
                )
            self.check_option("filtering.implementation",
                              ["standard", "fast", "fast_if_available"])

            self.filter_implementation = self.get_option(
                "filtering.implementation")
        self.dataset = dataset
        # auto config
        for slot, copy_from in [(S, O), (P, None), (O, S)]:
            if self.num_samples[slot] < 0:
                if copy_from is not None and self.num_samples[copy_from] > 0:
                    self.num_samples[slot] = self.num_samples[copy_from]
                else:
                    self.num_samples[slot] = 0
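
A standalone sketch of the auto-config rule at the end of the constructor above; `S`, `P`, `O` are assumed to be the slot indices 0, 1, 2 as in the surrounding code:

import torch

S, P, O = 0, 1, 2
num_samples = torch.tensor([-1, -1, 30])  # e.g. only num_samples.o was configured

for slot, copy_from in [(S, O), (P, None), (O, S)]:
    if num_samples[slot] < 0:
        if copy_from is not None and num_samples[copy_from] > 0:
            num_samples[slot] = num_samples[copy_from]
        else:
            num_samples[slot] = 0

print(num_samples)  # tensor([30,  0, 30]): subject count copied from the object slot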
Example #9
 def _create_dataset_and_indexes(self):
     data = Dataset.create(config=self.config,
                           folder=self.dataset_folder,
                           preload_data=True)
     indexes = []
     for index_key in data.index_functions.keys():
         indexes.append(data.index(index_key))
     return data, indexes
Example #10
 def setUp(self):
     self.config = create_config(self.dataset_name, model=self.model_name)
     self.config.set_all({"lookup_embedder.dim": 32})
     self.config.set_all(self.options)
     self.dataset_folder = get_dataset_folder(self.dataset_name)
     self.dataset = Dataset.create(
         self.config, folder=get_dataset_folder(self.dataset_name)
     )
     self.model = KgeModel.create(self.config, self.dataset)
Example #11
    def __init__(
        self,
        config: Config,
        dataset: Dataset,
        scorer: Union[RelationalScorer, type],
        initialize_embedders=True,
        configuration_key=None,
    ):
        super().__init__(config, dataset, configuration_key)

        # TODO support different embedders for subjects and objects

        #: Embedder used for entities (both subjects and objects)
        self._entity_embedder: KgeEmbedder

        #: Embedder used for relations
        self._relation_embedder: KgeEmbedder

        if initialize_embedders:
            self._entity_embedder = KgeEmbedder.create(
                config,
                dataset,
                self.configuration_key + ".entity_embedder",
                dataset.num_entities(),
            )

            #: Embedder used for relations
            num_relations = dataset.num_relations()
            self._relation_embedder = KgeEmbedder.create(
                config,
                dataset,
                self.configuration_key + ".relation_embedder",
                num_relations,
            )

        #: Scorer
        self._scorer: RelationalScorer
        if type(scorer) == type:
            # scorer is the class of the scorer to use; call its constructor
            self._scorer = scorer(config=config,
                                  dataset=dataset,
                                  configuration_key=self.configuration_key)
        else:
            self._scorer = scorer
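
A hedged sketch of the scorer convention handled at the end of this constructor: a concrete model may pass either a `RelationalScorer` subclass (instantiated here) or an already-built scorer instance, as the first constructor in this listing does with `base_model.get_scorer()`. The import path and the `DistMultScorer` name are assumptions for illustration:

from kge.model.distmult import DistMultScorer  # assumed example scorer class

class MyModel(KgeModel):
    def __init__(self, config, dataset, configuration_key=None):
        self._init_configuration(config, configuration_key)
        super().__init__(
            config=config,
            dataset=dataset,
            scorer=DistMultScorer,  # a class, not an instance: constructed in KgeModel.__init__
            configuration_key=self.configuration_key,
        )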
Example #12
 def setUp(self):
     self.config = create_config(self.dataset_name)
     self.config.set_all({"lookup_embedder.dim": 32})
     self.config.set("job.type", "train")
     self.config.set("train.type", self.train_type)
     self.config.set_all(self.options)
     self.dataset_folder = get_dataset_folder(self.dataset_name)
     self.dataset = Dataset.create(self.config,
                                   folder=get_dataset_folder(
                                       self.dataset_name))
     self.model = KgeModel.create(self.config, self.dataset)
Example #13
 def test_store_data_pickle(self):
     # this will create new pickle files for train, valid, test
     dataset = Dataset.create(config=self.config,
                              folder=self.dataset_folder,
                              preload_data=True)
     pickle_filenames = [
         "train.del-t.pckl",
         "valid.del-t.pckl",
         "test.del-t.pckl",
         "entity_ids.del-True-t-False.pckl",
         "relation_ids.del-True-t-False.pckl",
     ]
     for filename in pickle_filenames:
         self.assertTrue(
             os.path.isfile(os.path.join(self.dataset_folder, filename)),
             msg=filename,
         )
Example #14
    def create_default(
        model: Optional[str] = None,
        dataset: Optional[Union[Dataset, str]] = None,
        options: Dict[str, Any] = {},
        folder: Optional[str] = None,
    ) -> "KgeModel":
        """Utility method to create a model, including configuration and dataset.

        `model` is the name of the model (takes precedence over
        ``options["model"]``), `dataset` a dataset name or `Dataset` instance (takes
        precedence over ``options["dataset.name"]``), and `options` any further
        configuration options.

        If `folder` is ``None``, creates a temporary folder. Otherwise uses the
        specified folder.

        """
        # load default model config
        if model is None:
            model = options["model"]
        default_config_file = filename_in_module(kge.model, "{}.yaml".format(model))
        config = Config()
        config.load(default_config_file, create=True)

        # apply specified options
        config.set("model", model)
        if isinstance(dataset, Dataset):
            config.set("dataset.name", dataset.config.get("dataset.name"))
        elif isinstance(dataset, str):
            config.set("dataset.name", dataset)
        config.set_all(new_options=options)

        # create output folder
        if folder is None:
            config.folder = tempfile.mkdtemp(
                "{}-{}-".format(config.get("dataset.name"), config.get("model"))
            )
        else:
            config.folder = folder

        # create dataset and model
        if not isinstance(dataset, Dataset):
            dataset = Dataset.create(config)
        model = KgeModel.create(config, dataset)
        return model
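
A minimal usage sketch for the utility above; the model and dataset names are examples and assume they are available in your installation:

from kge.model import KgeModel

# "complex" / "toy" are example names; any registered model and dataset work the same way.
model = KgeModel.create_default(
    model="complex",
    dataset="toy",
    options={"lookup_embedder.dim": 32},
)
print(model.config.folder)  # a temporary folder, because folder=None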
Example #15
    def create_from(cls,
                    checkpoint: Dict,
                    new_config: Config = None,
                    dataset: Dataset = None,
                    parent_job=None,
                    parameter_client=None) -> Job:
        """
        Creates a Job based on a checkpoint
        Args:
            checkpoint: loaded checkpoint
            new_config: optional config object - overwrites options of config
                stored in checkpoint
            dataset: dataset object
            parent_job: parent job (e.g. search job)

        Returns: Job based on checkpoint

        """
        from kge.model import KgeModel

        model: KgeModel = None
        # search jobs don't have a model
        if "model" in checkpoint and checkpoint["model"] is not None:
            model = KgeModel.create_from(checkpoint,
                                         new_config=new_config,
                                         dataset=dataset,
                                         parameter_client=parameter_client)
            config = model.config
            dataset = model.dataset
        else:
            config = Config.create_from(checkpoint)
            if new_config:
                config.load_config(new_config)
            dataset = Dataset.create_from(checkpoint, config, dataset)
        job = Job.create(config,
                         dataset,
                         parent_job,
                         model,
                         parameter_client=parameter_client,
                         init_for_load_only=True)
        job._load(checkpoint)
        job.config.log("Loaded checkpoint from {}...".format(
            checkpoint["file"]))
        return job
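
A short sketch of resuming a job from a checkpoint with the factory above, following the same steps the `main()` entry point below performs; the checkpoint path is a placeholder:

from kge.util.io import load_checkpoint
from kge.job import Job

# Placeholder path; "cpu" keeps the sketch device-independent.
checkpoint = load_checkpoint("local/experiments/my-run/checkpoint_best.pt", device="cpu")
job = Job.create_from(checkpoint)
job.run()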
Example #16
def main():
    # default config
    config = Config()

    # now parse the arguments
    parser = create_parser(config)
    args, unknown_args = parser.parse_known_args()

    # If there were unknown args, add them to the parser and reparse. The correctness
    # of these arguments will be checked later.
    if len(unknown_args) > 0:
        parser = create_parser(
            config, filter(lambda a: a.startswith("--"), unknown_args)
        )
        args = parser.parse_args()

    # process meta-commands
    process_meta_command(args, "create", {"command": "start", "run": False})
    process_meta_command(args, "eval", {"command": "resume", "job.type": "eval"})
    process_meta_command(
        args, "test", {"command": "resume", "job.type": "eval", "eval.split": "test"}
    )
    process_meta_command(
        args, "valid", {"command": "resume", "job.type": "eval", "eval.split": "valid"}
    )
    # dump command
    if args.command == "dump":
        dump(args)
        exit()

    # package command
    if args.command == "package":
        package_model(args)
        exit()

    # start command
    if args.command == "start":
        # use toy config file if no config given
        if args.config is None:
            args.config = kge_base_dir() + "/" + "examples/toy-complex-train.yaml"
            print(
                "WARNING: No configuration specified; using " + args.config,
                file=sys.stderr,
            )

        if not vars(args)["console.quiet"]:
            print("Loading configuration {}...".format(args.config))
        config.load(args.config)

    # resume command
    if args.command == "resume":
        if os.path.isdir(args.config) and os.path.isfile(args.config + "/config.yaml"):
            args.config += "/config.yaml"
        if not vars(args)["console.quiet"]:
            print("Resuming from configuration {}...".format(args.config))
        config.load(args.config)
        config.folder = os.path.dirname(args.config)
        if not config.folder:
            config.folder = "."
        if not os.path.exists(config.folder):
            raise ValueError(
                "{} is not a valid config file for resuming".format(args.config)
            )

    # overwrite configuration with command line arguments
    for key, value in vars(args).items():
        if key in [
            "command",
            "config",
            "run",
            "folder",
            "checkpoint",
            "abort_when_cache_outdated",
        ]:
            continue
        if value is not None:
            if key == "search.device_pool":
                value = "".join(value).split(",")
            try:
                if isinstance(config.get(key), bool):
                    value = argparse_bool_type(value)
            except KeyError:
                pass
            config.set(key, value)
            if key == "model":
                config._import(value)

    # initialize output folder
    if args.command == "start":
        if args.folder is None:  # means: set default
            config_name = os.path.splitext(os.path.basename(args.config))[0]
            config.folder = os.path.join(
                kge_base_dir(),
                "local",
                "experiments",
                datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + config_name,
            )
        else:
            config.folder = args.folder

    # catch errors to log them
    try:
        if args.command == "start" and not config.init_folder():
            raise ValueError("output folder {} exists already".format(config.folder))
        config.log("Using folder: {}".format(config.folder))

        # determine checkpoint to resume (if any)
        if hasattr(args, "checkpoint"):
            checkpoint_file = get_checkpoint_file(config, args.checkpoint)

        # disable processing of outdated cached dataset files globally
        Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated

        # log configuration
        config.log("Configuration:")
        config.log(yaml.dump(config.options), prefix="  ")
        config.log("git commit: {}".format(get_git_revision_short_hash()), prefix="  ")

        # set random seeds
        def get_seed(what):
            seed = config.get(f"random_seed.{what}")
            if seed < 0 and config.get(f"random_seed.default") >= 0:
                import hashlib

                # we add an md5 hash to the default seed so that different PRNGs get a
                # different seed
                seed = (
                    config.get(f"random_seed.default")
                    + int(hashlib.md5(what.encode()).hexdigest(), 16)
                ) % 0xFFFF  # keep the derived seed in a 16-bit range

            return seed

        if get_seed("python") > -1:
            import random

            random.seed(get_seed("python"))
        if get_seed("torch") > -1:
            import torch

            torch.manual_seed(get_seed("torch"))
        if get_seed("numpy") > -1:
            import numpy.random

            numpy.random.seed(get_seed("numpy"))
        if get_seed("numba") > -1:
            import numpy as np, numba

            @numba.njit
            def seed_numba(seed):
                np.random.seed(seed)

            seed_numba(get_seed("numba"))

        # let's go
        if args.command == "start" and not args.run:
            config.log("Job created successfully.")
        else:
            # load data
            dataset = Dataset.create(config)

            # let's go
            if args.command == "resume":
                if checkpoint_file is not None:
                    checkpoint = load_checkpoint(
                        checkpoint_file, config.get("job.device")
                    )
                    job = Job.create_from(
                        checkpoint, new_config=config, dataset=dataset
                    )
                else:
                    job = Job.create(config, dataset)
                    job.config.log(
                        "No checkpoint found or specified, starting from scratch..."
                    )
            else:
                job = Job.create(config, dataset)
            job.run()
    except BaseException:
        tb = traceback.format_exc()
        config.log(tb, echo=False)
        raise
Example #17
    def __init__(
        self,
        config: Config,
        dataset: Dataset,
        scorer: Union[RelationalScorer, type],
        create_embedders=True,
        configuration_key=None,
        init_for_load_only=False,
    ):
        super().__init__(config, dataset, configuration_key)

        # TODO support different embedders for subjects and objects

        #: Embedder used for entities (both subjects and objects)
        self._entity_embedder: KgeEmbedder

        #: Embedder used for relations
        self._relation_embedder: KgeEmbedder

        if create_embedders:
            self._entity_embedder = KgeEmbedder.create(
                config,
                dataset,
                self.configuration_key + ".entity_embedder",
                dataset.num_entities(),
                init_for_load_only=init_for_load_only,
            )

            #: Embedder used for relations
            num_relations = dataset.num_relations()
            self._relation_embedder = KgeEmbedder.create(
                config,
                dataset,
                self.configuration_key + ".relation_embedder",
                num_relations,
                init_for_load_only=init_for_load_only,
            )

            if not init_for_load_only:
                # load pretrained embeddings
                pretrained_entities_filename = ""
                pretrained_relations_filename = ""
                if self.has_option("entity_embedder.pretrain.model_filename"):
                    pretrained_entities_filename = self.get_option(
                        "entity_embedder.pretrain.model_filename"
                    )
                if self.has_option("relation_embedder.pretrain.model_filename"):
                    pretrained_relations_filename = self.get_option(
                        "relation_embedder.pretrain.model_filename"
                    )

                def load_pretrained_model(
                    pretrained_filename: str,
                ) -> Optional[KgeModel]:
                    if pretrained_filename != "":
                        self.config.log(
                            f"Initializing with embeddings stored in "
                            f"{pretrained_filename}"
                        )
                        checkpoint = load_checkpoint(pretrained_filename)
                        return KgeModel.create_from(checkpoint)
                    return None

                pretrained_entities_model = load_pretrained_model(
                    pretrained_entities_filename
                )
                if pretrained_entities_filename == pretrained_relations_filename:
                    pretrained_relations_model = pretrained_entities_model
                else:
                    pretrained_relations_model = load_pretrained_model(
                        pretrained_relations_filename
                    )
                if pretrained_entities_model is not None:
                    if (
                        pretrained_entities_model.get_s_embedder()
                        != pretrained_entities_model.get_o_embedder()
                    ):
                        raise ValueError(
                            "Can only initialize with pre-trained models having "
                            "identical subject and object embeddings."
                        )
                    self._entity_embedder.init_pretrained(
                        pretrained_entities_model.get_s_embedder()
                    )
                if pretrained_relations_model is not None:
                    self._relation_embedder.init_pretrained(
                        pretrained_relations_model.get_p_embedder()
                    )

        #: Scorer
        self._scorer: RelationalScorer
        if type(scorer) == type:
            # scorer is the class of the scorer to use; call its constructor
            self._scorer = scorer(
                config=config, dataset=dataset, configuration_key=self.configuration_key
            )
        else:
            self._scorer = scorer
Example #18
def main():
    # default config
    config = Config()

    # now parse the arguments
    parser = create_parser(config)
    args, unknown_args = parser.parse_known_args()

    # If there were unknown args, add them to the parser and reparse. The correctness
    # of these arguments will be checked later.
    if len(unknown_args) > 0:
        parser = create_parser(
            config, filter(lambda a: a.startswith("--"), unknown_args)
        )
        args = parser.parse_args()

    # process meta-commands
    process_meta_command(args, "create", {"command": "start", "run": False})
    process_meta_command(args, "eval", {"command": "resume", "job.type": "eval"})
    process_meta_command(
        args, "test", {"command": "resume", "job.type": "eval", "eval.split": "test"}
    )
    process_meta_command(
        args, "valid", {"command": "resume", "job.type": "eval", "eval.split": "valid"}
    )
    # dump command
    if args.command == "dump":
        dump(args)
        exit()

    # start command
    if args.command == "start":
        # use toy config file if no config given
        if args.config is None:
            args.config = kge_base_dir() + "/" + "examples/toy-complex-train.yaml"
            print("WARNING: No configuration specified; using " + args.config)

        print("Loading configuration {}...".format(args.config))
        config.load(args.config)

    # resume command
    if args.command == "resume":
        if os.path.isdir(args.config) and os.path.isfile(args.config + "/config.yaml"):
            args.config += "/config.yaml"
        print("Resuming from configuration {}...".format(args.config))
        config.load(args.config)
        config.folder = os.path.dirname(args.config)
        if not config.folder:
            config.folder = "."
        if not os.path.exists(config.folder):
            raise ValueError(
                "{} is not a valid config file for resuming".format(args.config)
            )

    # overwrite configuration with command line arguments
    for key, value in vars(args).items():
        if key in [
            "command",
            "config",
            "run",
            "folder",
            "checkpoint",
            "abort_when_cache_outdated",
        ]:
            continue
        if value is not None:
            if key == "search.device_pool":
                value = "".join(value).split(",")
            try:
                if isinstance(config.get(key), bool):
                    value = argparse_bool_type(value)
            except KeyError:
                pass
            config.set(key, value)
            if key == "model":
                config._import(value)

    # initialize output folder
    if args.command == "start":
        if args.folder is None:  # means: set default
            config_name = os.path.splitext(os.path.basename(args.config))[0]
            config.folder = os.path.join(
                kge_base_dir(),
                "local",
                "experiments",
                datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + config_name,
            )
        else:
            config.folder = args.folder

    # catch errors to log them
    try:
        if args.command == "start" and not config.init_folder():
            raise ValueError("output folder {} exists already".format(config.folder))
        config.log("Using folder: {}".format(config.folder))

        # determine checkpoint to resume (if any)
        if hasattr(args, "checkpoint"):
            if args.checkpoint == "default":
                if config.get("job.type") in ["eval", "valid"]:
                    checkpoint_file = config.checkpoint_file("best")
                else:
                    checkpoint_file = None  # means last
            elif is_number(args.checkpoint, int) or args.checkpoint == "best":
                checkpoint_file = config.checkpoint_file(args.checkpoint)
            else:
                # otherwise, treat it as a filename
                checkpoint_file = args.checkpoint

        # disable processing of outdated cached dataset files globally
        Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated

        # log configuration
        config.log("Configuration:")
        config.log(yaml.dump(config.options), prefix="  ")
        config.log("git commit: {}".format(get_git_revision_short_hash()), prefix="  ")

        # set random seeds
        if config.get("random_seed.python") > -1:
            import random

            random.seed(config.get("random_seed.python"))
        if config.get("random_seed.torch") > -1:
            import torch

            torch.manual_seed(config.get("random_seed.torch"))
        if config.get("random_seed.numpy") > -1:
            import numpy.random

            numpy.random.seed(config.get("random_seed.numpy"))

        # let's go
        if args.command == "start" and not args.run:
            config.log("Job created successfully.")
        else:
            # load data
            dataset = Dataset.load(config)

            # let's go
            job = Job.create(config, dataset)
            if args.command == "resume":
                job.resume(checkpoint_file)
            job.run()
    except BaseException as e:
        tb = traceback.format_exc()
        config.log(tb, echo=False)
        raise e from None
Example #19
    def __init__(
        self,
        config: Config,
        dataset: Dataset,
        scorer: Union[RelationalScorer, type],
        create_embedders=True,
        configuration_key=None,
        init_for_load_only=False,
        parameter_client=None,
        max_partition_entities=0,
    ):
        super().__init__(config, dataset, configuration_key)

        # TODO support different embedders for subjects and objects

        #: Embedder used for entities (both subjects and objects)
        self._entity_embedder: KgeEmbedder

        #: Embedder used for relations
        self._relation_embedder: KgeEmbedder

        if create_embedders:
            self._create_embedders(init_for_load_only)
        elif False:
            # if self.get_option("create_complete"):
            #     embedding_layer_size = dataset.num_entities()
            if (
                config.get("job.distributed.entity_sync_level") == "partition"
                and max_partition_entities != 0
            ):
                embedding_layer_size = max_partition_entities
            else:
                embedding_layer_size = self._calc_embedding_layer_size(config, dataset)
            config.log(f"creating entity_embedder with {embedding_layer_size} keys")
            self._entity_embedder = KgeEmbedder.create(
                config=config,
                dataset=dataset,
                configuration_key=self.configuration_key + ".entity_embedder",
                #dataset.num_entities(),
                vocab_size=embedding_layer_size,
                init_for_load_only=init_for_load_only,
                parameter_client=parameter_client,
                lapse_offset=0,
                complete_vocab_size=dataset.num_entities()
            )

            #: Embedder used for relations
            num_relations = dataset.num_relations()
            self._relation_embedder = KgeEmbedder.create(
                config,
                dataset,
                self.configuration_key + ".relation_embedder",
                num_relations,
                init_for_load_only=init_for_load_only,
                parameter_client=parameter_client,
                lapse_offset=dataset.num_entities(),
                complete_vocab_size=dataset.num_relations(),
            )

            if not init_for_load_only and parameter_client.rank == get_min_rank(config):
                # load pretrained embeddings
                pretrained_entities_filename = ""
                pretrained_relations_filename = ""
                if self.has_option("entity_embedder.pretrain.model_filename"):
                    pretrained_entities_filename = self.get_option(
                        "entity_embedder.pretrain.model_filename"
                    )
                if self.has_option("relation_embedder.pretrain.model_filename"):
                    pretrained_relations_filename = self.get_option(
                        "relation_embedder.pretrain.model_filename"
                    )

                def load_pretrained_model(
                    pretrained_filename: str,
                ) -> Optional[KgeModel]:
                    if pretrained_filename != "":
                        self.config.log(
                            f"Initializing with embeddings stored in "
                            f"{pretrained_filename}"
                        )
                        checkpoint = load_checkpoint(pretrained_filename)
                        return KgeModel.create_from(checkpoint, parameter_client=parameter_client)
                    return None

                pretrained_entities_model = load_pretrained_model(
                    pretrained_entities_filename
                )
                if pretrained_entities_filename == pretrained_relations_filename:
                    pretrained_relations_model = pretrained_entities_model
                else:
                    pretrained_relations_model = load_pretrained_model(
                        pretrained_relations_filename
                    )
                if pretrained_entities_model is not None:
                    if (
                        pretrained_entities_model.get_s_embedder()
                        != pretrained_entities_model.get_o_embedder()
                    ):
                        raise ValueError(
                            "Can only initialize with pre-trained models having "
                            "identical subject and object embeddings."
                        )
                    self._entity_embedder.init_pretrained(
                        pretrained_entities_model.get_s_embedder()
                    )
                if pretrained_relations_model is not None:
                    self._relation_embedder.init_pretrained(
                        pretrained_relations_model.get_p_embedder()
                    )

        #: Scorer
        self._scorer: RelationalScorer
        if type(scorer) == type:
            # scorer is the class of the scorer to use; call its constructor
            self._scorer = scorer(
                config=config, dataset=dataset, configuration_key=self.configuration_key
            )
        else:
            self._scorer = scorer