Example #1
0
    def create_from(
        checkpoint: Dict,
        dataset: Optional[Dataset] = None,
        use_tmp_log_folder=True,
        new_config: Config = None,
    ) -> "KgeModel":
        """Restore a model from a training-job checkpoint or a packaged model.

        The config stored in ``checkpoint`` is used as the base. If ``new_config``
        is given, its options are loaded on top of it.

        If ``dataset`` is specified, it is associated with the model. Otherwise the
        dataset used to train the model is used (data is not preloaded).

        If ``use_tmp_log_folder`` is set, logs and traces are written to a freshly
        created temporary directory (prefix ``kge-``). Otherwise the files
        ``kge.log`` and ``trace.yaml`` are created (or appended to) in the
        checkpoint's folder. If that folder is unset or does not exist, the current
        directory is used instead.

        The returned model has had ``eval()`` called on it.
        """
        cfg = Config.create_from(checkpoint)
        if new_config:
            cfg.load_config(new_config)

        if not use_tmp_log_folder:
            # Log next to the checkpoint; fall back to the working directory when
            # no usable folder is recorded.
            cfg.log_folder = checkpoint["folder"]
            folder_unusable = not cfg.log_folder or not os.path.exists(cfg.log_folder)
            if folder_unusable:
                cfg.log_folder = "."
        else:
            import tempfile

            cfg.log_folder = tempfile.mkdtemp(prefix="kge-")

        ds = Dataset.create_from(checkpoint, cfg, dataset, preload_data=False)
        restored = KgeModel.create(cfg, ds, init_for_load_only=True)
        restored.load(checkpoint["model"])
        restored.eval()
        return restored
Example #2
0
def package_model(args):
    """
    Converts a checkpoint to a packaged model and saves it with ``torch.save``.

    A packaged model only contains the model, entity/relation ids and the config.
    It also carries the ``epoch``, ``job_id`` and ``valid_trace`` entries copied
    from the checkpoint.

    Args:
        args: parsed command-line arguments.
            ``args.checkpoint`` is the path of the checkpoint to package.
            ``args.file`` is the output path, or ``None`` to derive it from the
            checkpoint path.
            The derived file lives in the checkpoint's directory. If the
            checkpoint's file name contains ``checkpoint``, that text is replaced
            by ``model``. Otherwise a trailing ``.pt`` (if any) is removed and
            ``_package.pt`` is appended.

    Raises:
        ValueError: if the checkpoint's ``type`` is not ``"train"``.
    """
    checkpoint_file = args.checkpoint
    filename = args.file
    checkpoint = load_checkpoint(checkpoint_file, device="cpu")
    if checkpoint["type"] != "train":
        raise ValueError("Can only package trained checkpoints.")
    config = Config.create_from(checkpoint)
    dataset = Dataset.create_from(checkpoint, config, preload_data=False)
    packaged_model = {
        "type": "package",
        "model": checkpoint["model"],
        "epoch": checkpoint["epoch"],
        "job_id": checkpoint["job_id"],
        "valid_trace": checkpoint["valid_trace"],
    }
    packaged_model = config.save_to(packaged_model)
    # Only the id mappings are stored, not the full dataset.
    packaged_model = dataset.save_to(
        packaged_model,
        ["entity_ids", "relation_ids"],
    )
    if filename is None:
        output_folder, filename = os.path.split(checkpoint_file)
        if "checkpoint" in filename:
            filename = filename.replace("checkpoint", "model")
        else:
            # Strip only a trailing ".pt" extension. Splitting on ".pt" would
            # truncate at its first occurrence anywhere in the name
            # (e.g. "run.ptb_best.pt" -> "run").
            if filename.endswith(".pt"):
                filename = filename[: -len(".pt")]
            filename += "_package.pt"
        filename = os.path.join(output_folder, filename)
    print(f"Saving to {filename}...")
    torch.save(packaged_model, filename)
Example #3
0
    def create_from(cls,
                    checkpoint: Dict,
                    new_config: Config = None,
                    dataset: Dataset = None,
                    parent_job=None,
                    parameter_client=None) -> Job:
        """
        Restores a Job from a loaded checkpoint.

        Args:
            checkpoint: loaded checkpoint dict
            new_config: optional config object; its options overwrite those of
                the config stored in the checkpoint
            dataset: optional dataset object
            parent_job: parent job (e.g. search job)
            parameter_client: passed on to ``KgeModel.create_from`` (when the
                checkpoint holds a model) and to ``Job.create``

        Returns: Job based on checkpoint, with the checkpoint state loaded

        """
        from kge.model import KgeModel

        stored_model = checkpoint.get("model")
        restored_model: KgeModel = None
        if stored_model is None:
            # No model in the checkpoint (e.g. search jobs): build config and
            # dataset directly from the checkpoint.
            job_config = Config.create_from(checkpoint)
            if new_config:
                job_config.load_config(new_config)
            job_dataset = Dataset.create_from(checkpoint, job_config, dataset)
        else:
            restored_model = KgeModel.create_from(
                checkpoint,
                new_config=new_config,
                dataset=dataset,
                parameter_client=parameter_client,
            )
            # Reuse the config and dataset the restored model was built with.
            job_config = restored_model.config
            job_dataset = restored_model.dataset

        job = Job.create(
            job_config,
            job_dataset,
            parent_job,
            restored_model,
            parameter_client=parameter_client,
            init_for_load_only=True,
        )
        job._load(checkpoint)
        job.config.log(f"Loaded checkpoint from {checkpoint['file']}...")
        return job