def create_from(
    checkpoint: Dict,
    dataset: Optional[Dataset] = None,
    use_tmp_log_folder=True,
    new_config: Config = None,
    parameter_client=None,
) -> "KgeModel":
    """Loads a model from a checkpoint of a training job or a packaged model.

    If `dataset` is specified, associates this dataset with the model.
    Otherwise uses the dataset used to train the model.

    If `use_tmp_log_folder` is set, the logs and traces are written to a
    freshly created temporary folder. Otherwise, the files `kge.log` and
    `trace.yaml` will be created (or appended to) in the checkpoint's folder;
    if the checkpoint carries no folder, or that folder does not exist, the
    current directory (".") is used instead.

    Args:
        checkpoint: loaded checkpoint dictionary; must contain "model" and
            whatever `Config.create_from` / `Dataset.create_from` read.
        dataset: optional dataset handed to `Dataset.create_from`.
        use_tmp_log_folder: whether to log into a temporary folder.
        new_config: optional config whose options are loaded on top of the
            config stored in the checkpoint.
        parameter_client: optional object accepted so that callers passing
            `parameter_client=` by keyword do not fail. It is forwarded to
            `KgeModel.create` only when it is not None.
            NOTE(review): assumes `KgeModel.create` accepts a
            `parameter_client` keyword in that case -- confirm.

    Returns:
        The model with the checkpoint's parameters loaded, in eval mode.
    """
    config = Config.create_from(checkpoint)
    if new_config:
        config.load_config(new_config)

    if use_tmp_log_folder:
        import tempfile

        config.log_folder = tempfile.mkdtemp(prefix="kge-")
    else:
        # Use .get(): the fallback below is meant to cover a missing folder,
        # which a plain subscript would turn into a KeyError first.
        log_folder = checkpoint.get("folder")
        if not log_folder or not os.path.exists(log_folder):
            log_folder = "."
        config.log_folder = log_folder

    dataset = Dataset.create_from(checkpoint, config, dataset, preload_data=False)

    # Only forward parameter_client when given, so calls that never pass it
    # reach KgeModel.create with exactly the arguments they always did.
    create_kwargs = {"init_for_load_only": True}
    if parameter_client is not None:
        create_kwargs["parameter_client"] = parameter_client
    model = KgeModel.create(config, dataset, **create_kwargs)

    model.load(checkpoint["model"])
    model.eval()
    return model
def package_model(args):
    """Converts a checkpoint to a packaged model.

    A packaged model only contains the model, entity/relation ids and the
    config (plus the epoch, job id and validation trace of the checkpoint).

    Args:
        args: object with the attributes `checkpoint` (path of the checkpoint
            file to package) and `file` (output path, or None to derive one
            from the checkpoint's file name).

    Raises:
        ValueError: if the checkpoint's "type" is not "train".
    """
    checkpoint_file = args.checkpoint
    filename = args.file
    checkpoint = load_checkpoint(checkpoint_file, device="cpu")
    if checkpoint["type"] != "train":
        raise ValueError("Can only package trained checkpoints.")

    config = Config.create_from(checkpoint)
    dataset = Dataset.create_from(checkpoint, config, preload_data=False)

    packaged_model = {
        "type": "package",
        "model": checkpoint["model"],
        "epoch": checkpoint["epoch"],
        "job_id": checkpoint["job_id"],
        "valid_trace": checkpoint["valid_trace"],
    }
    packaged_model = config.save_to(packaged_model)
    packaged_model = dataset.save_to(
        packaged_model,
        ["entity_ids", "relation_ids"],
    )

    if filename is None:
        # Derive the output name from the checkpoint's own file name and keep
        # it in the same folder.
        output_folder, filename = os.path.split(checkpoint_file)
        if "checkpoint" in filename:
            filename = filename.replace("checkpoint", "model")
        else:
            # Strip only a trailing ".pt". Splitting on ".pt" would cut at its
            # first occurrence anywhere in the name (e.g. "run.pth.pt" would
            # become "run").
            extension = ".pt"
            if filename.endswith(extension):
                filename = filename[: -len(extension)]
            filename = filename + "_package.pt"
        filename = os.path.join(output_folder, filename)

    print(f"Saving to {filename}...")
    torch.save(packaged_model, filename)
def create_from(
    cls,
    checkpoint: Dict,
    new_config: Config = None,
    dataset: Dataset = None,
    parent_job=None,
    parameter_client=None,
) -> Job:
    """Creates a Job based on a checkpoint.

    Args:
        checkpoint: loaded checkpoint
        new_config: optional config object - overwrites options of the config
            stored in the checkpoint
        dataset: dataset object
        parent_job: parent job (e.g. search job)
        parameter_client: passed on unchanged to `KgeModel.create_from` (when
            the checkpoint holds a model) and to `Job.create`

    Returns:
        Job based on checkpoint
    """
    from kge.model import KgeModel

    restored_model: KgeModel = None
    if checkpoint.get("model") is None:
        # No model stored (search jobs don't have one): rebuild config and
        # dataset directly from the checkpoint.
        config = Config.create_from(checkpoint)
        if new_config:
            config.load_config(new_config)
        dataset = Dataset.create_from(checkpoint, config, dataset)
    else:
        # The model loader already resolves config and dataset; reuse them.
        restored_model = KgeModel.create_from(
            checkpoint,
            new_config=new_config,
            dataset=dataset,
            parameter_client=parameter_client,
        )
        config, dataset = restored_model.config, restored_model.dataset

    restored_job = Job.create(
        config,
        dataset,
        parent_job,
        restored_model,
        parameter_client=parameter_client,
        init_for_load_only=True,
    )
    restored_job._load(checkpoint)
    restored_job.config.log(
        "Loaded checkpoint from {}...".format(checkpoint["file"])
    )
    return restored_job