Exemplo n.º 1
0
def package_model(args):
    """
    Converts a checkpoint to a packaged model.

    A packaged model only contains the model, entity/relation ids and the config.

    Args:
        args: Object exposing ``checkpoint`` (path of the checkpoint to read)
            and ``file`` (path to write the package to, or ``None`` to derive
            the output path from the checkpoint path).

    Raises:
        ValueError: If the checkpoint's ``"type"`` entry is not ``"train"``.
    """
    checkpoint_file = args.checkpoint
    filename = args.file
    checkpoint = load_checkpoint(checkpoint_file, device="cpu")
    if checkpoint["type"] != "train":
        raise ValueError("Can only package trained checkpoints.")
    config = Config.create_from(checkpoint)
    dataset = Dataset.create_from(checkpoint, config, preload_data=False)
    # Only a subset of the checkpoint entries is carried over to the package.
    packaged_model = {
        "type": "package",
        "model": checkpoint["model"],
        "epoch": checkpoint["epoch"],
        "job_id": checkpoint["job_id"],
        "valid_trace": checkpoint["valid_trace"],
    }
    packaged_model = config.save_to(packaged_model)
    packaged_model = dataset.save_to(
        packaged_model,
        ["entity_ids", "relation_ids"],
    )
    if filename is None:
        # No explicit target: write next to the checkpoint under a derived name.
        output_folder, filename = os.path.split(checkpoint_file)
        if "checkpoint" in filename:
            filename = filename.replace("checkpoint", "model")
        else:
            # Strip only the final extension. Splitting on the substring ".pt"
            # would truncate names that contain ".pt" earlier in the stem
            # (e.g. "run.ptb.best.pt" would collapse to "run").
            filename = os.path.splitext(filename)[0] + "_package.pt"
        filename = os.path.join(output_folder, filename)
    print(f"Saving to {filename}...")
    torch.save(packaged_model, filename)
Exemplo n.º 2
0
 def load_pretrained_model(pretrained_filename: str) -> Optional[KgeModel]:
     """Build a model from the checkpoint stored at ``pretrained_filename``.

     Returns ``None`` when ``pretrained_filename`` is the empty string;
     otherwise logs the filename, loads the checkpoint and returns the model
     created from it.

     NOTE(review): ``self`` is not a parameter here, so it is presumably
     captured from an enclosing scope -- confirm against the surrounding code.
     """
     # Guard clause: an empty filename means there is nothing to load.
     if pretrained_filename == "":
         return None
     self.config.log(
         f"Initializing with embeddings stored in {pretrained_filename}")
     pretrained_checkpoint = load_checkpoint(pretrained_filename)
     return KgeModel.create_from(pretrained_checkpoint)