Example #1
    @classmethod
    def create_from(cls,
                    checkpoint: Dict,
                    new_config: Config = None,
                    dataset: Dataset = None,
                    parent_job=None,
                    parameter_client=None) -> Job:
        """
        Creates a Job based on a checkpoint
        Args:
            checkpoint: loaded checkpoint
            new_config: optional config object - overwrites options of config
                              stored in checkpoint
            dataset: dataset object
            parent_job: parent job (e.g. search job)

        Returns: Job based on checkpoint

        """
        from kge.model import KgeModel

        model: KgeModel = None
        # search jobs don't have a model
        if "model" in checkpoint and checkpoint["model"] is not None:
            model = KgeModel.create_from(checkpoint,
                                         new_config=new_config,
                                         dataset=dataset,
                                         parameter_client=parameter_client)
            config = model.config
            dataset = model.dataset
        else:
            config = Config.create_from(checkpoint)
            if new_config:
                config.load_config(new_config)
            dataset = Dataset.create_from(checkpoint, config, dataset)
        job = Job.create(config,
                         dataset,
                         parent_job,
                         model,
                         parameter_client=parameter_client,
                         init_for_load_only=True)
        job._load(checkpoint)
        job.config.log("Loaded checkpoint from {}...".format(
            checkpoint["file"]))
        return job
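
A minimal usage sketch for the classmethod above: load a checkpoint from disk with load_checkpoint and hand it to Job.create_from; job.run() then resumes from the restored state. The checkpoint path is a hypothetical placeholder.

from kge.job import Job
from kge.util.io import load_checkpoint

# Hypothetical path to a previously trained experiment's checkpoint.
checkpoint = load_checkpoint('local/experiments/my-experiment/checkpoint_best.pt')
job = Job.create_from(checkpoint)
job.run()  # resumes training/evaluation from the restored state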
Example #2
import torch
from kge.model import KgeModel
from kge.util.io import load_checkpoint
import numpy as np
# Evaluate link-prediction models (RESCAL, ComplEx, ConvE, DistMult, TransE) using pretrained FB15K-237 checkpoints
models = ['rescal', 'complex', 'conve', 'distmult', 'transe']

for m in models:
    # 1. Load the pretrained model via LibKGE.
    checkpoint = load_checkpoint(
        f'pretrained_models/FB15K-237/fb15k-237-{m}.pt')
    model = KgeModel.create_from(checkpoint)

    # 2. Create mappings from string identifiers to integer indices.
    # 2.1 Entity index mapping.
    entity_idxs = {
        e: e_idx
        for e_idx, e in enumerate(model.dataset.entity_ids())
    }
    # 2.2 Relation index mapping.
    relation_idxs = {
        r: r_idx
        for r_idx, r in enumerate(model.dataset.relation_ids())
    }
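
A hedged sketch of what these index mappings are typically used for: converting string identifiers into tensor indices and scoring a triple with the loaded model (here, the model loaded in the last loop iteration). The entity and relation identifiers below are hypothetical placeholders, not ids guaranteed to exist in the dataset.

# Assumption: 'some_head', 'some_relation', 'some_tail' stand in for real
# identifiers from model.dataset.entity_ids() / relation_ids().
s = torch.tensor([entity_idxs['some_head']])
p = torch.tensor([relation_idxs['some_relation']])
o = torch.tensor([entity_idxs['some_tail']])
score = model.score_spo(s, p, o, direction='o')  # LibKGE triple scoring
print(score.item())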
Example #3
def train(data_path,
          neg_batch_size,
          batch_size,
          shuffle,
          num_workers,
          nb_epochs,
          embedding_dim,
          hidden_dim,
          relation_dim,
          gpu,
          use_cuda,
          patience,
          freeze,
          validate_every,
          hops,
          lr,
          entdrop,
          reldrop,
          scoredrop,
          l3_reg,
          model_name,
          decay,
          ls,
          load_from,
          outfile,
          do_batch_norm,
          valid_data_path=None):
    print('Loading entities and relations')
    kg_type = 'full'
    if 'half' in hops:
        kg_type = 'half'
    checkpoint_file = '../../pretrained_models/embeddings/ComplEx_fbwq_' + kg_type + '/checkpoint_best.pt'
    print('Loading kg embeddings from', checkpoint_file)
    kge_checkpoint = load_checkpoint(checkpoint_file)
    kge_model = KgeModel.create_from(kge_checkpoint)
    kge_model.eval()
    e = getEntityEmbeddings(kge_model, hops)

    print('Loaded entities and relations')

    entity2idx, idx2entity, embedding_matrix = prepare_embeddings(e)
    data = process_text_file(data_path, split=False)
    print('Train file processed, making dataloader')
    # word2ix,idx2word, max_len = get_vocab(data)
    # hops = str(num_hops)
    device = torch.device(gpu if use_cuda else "cpu")
    dataset = DatasetMetaQA(data, e, entity2idx)
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=shuffle,
                             num_workers=num_workers)
    print('Creating model...')
    model = RelationExtractor(embedding_dim=embedding_dim,
                              num_entities=len(idx2entity),
                              relation_dim=relation_dim,
                              pretrained_embeddings=embedding_matrix,
                              freeze=freeze,
                              device=device,
                              entdrop=entdrop,
                              reldrop=reldrop,
                              scoredrop=scoredrop,
                              l3_reg=l3_reg,
                              model=model_name,
                              ls=ls,
                              do_batch_norm=do_batch_norm)
    print('Model created!')
    if load_from != '':
        # model.load_state_dict(torch.load("checkpoints/roberta_finetune/" + load_from + ".pt"))
        fname = "checkpoints/roberta_finetune/" + load_from + ".pt"
        model.load_state_dict(
            torch.load(fname, map_location=lambda storage, loc: storage))
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = ExponentialLR(optimizer, decay)
    optimizer.zero_grad()
    best_score = -float("inf")
    best_model = model.state_dict()
    no_update = 0
    # time.sleep(10)
    for epoch in range(nb_epochs):
        phases = []
        for i in range(validate_every):
            phases.append('train')
        phases.append('valid')
        for phase in phases:
            if phase == 'train':
                model.train()
                # model.apply(set_bn_eval)
                loader = tqdm(data_loader,
                              total=len(data_loader),
                              unit="batches")
                running_loss = 0
                for i_batch, a in enumerate(loader):
                    model.zero_grad()
                    question_tokenized = a[0].to(device)
                    attention_mask = a[1].to(device)
                    positive_head = a[2].to(device)
                    positive_tail = a[3].to(device)
                    loss = model(question_tokenized=question_tokenized,
                                 attention_mask=attention_mask,
                                 p_head=positive_head,
                                 p_tail=positive_tail)
                    loss.backward()
                    optimizer.step()
                    running_loss += loss.item()
                    loader.set_postfix(Loss=running_loss /
                                       ((i_batch + 1) * batch_size),
                                       Epoch=epoch)
                    loader.set_description('{}/{}'.format(epoch, nb_epochs))
                    loader.update()

                scheduler.step()

            elif phase == 'valid':
                model.eval()
                eps = 0.0001
                answers, score = validate_v2(model=model,
                                             data_path=valid_data_path,
                                             entity2idx=entity2idx,
                                             train_dataloader=dataset,
                                             device=device,
                                             model_name=model_name)
                if score > best_score + eps:
                    best_score = score
                    no_update = 0
                    best_model = model.state_dict()
                    print(
                        hops +
                        " hop Validation accuracy (no relation scoring) increased from previous epoch",
                        score)
                    # writeToFile(answers, 'results_' + model_name + '_' + hops + '.txt')
                    # torch.save(best_model, "checkpoints/roberta_finetune/best_score_model.pt")
                    # torch.save(best_model, "checkpoints/roberta_finetune/" + outfile + ".pt")
                elif (score < best_score + eps) and (no_update < patience):
                    no_update += 1
                    print(
                        "Validation accuracy decreases to %f from %f, %d more epoch to check"
                        % (score, best_score, patience - no_update))
                elif no_update == patience:
                    print(
                        "Model has exceed patience. Saving best model and exiting"
                    )
                    # torch.save(best_model, "checkpoints/roberta_finetune/best_score_model.pt")
                    # torch.save(best_model, "checkpoints/roberta_finetune/" + outfile + ".pt")
                    exit()
                if epoch == nb_epochs - 1:
                    print("Final Epoch has reached. Stoping and saving model.")
                    # torch.save(best_model, "checkpoints/roberta_finetune/best_score_model.pt")
                    # torch.save(best_model, "checkpoints/roberta_finetune/" + outfile + ".pt")
                    exit()
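
A hypothetical invocation of train() above. Every argument value is an illustrative placeholder chosen for the sketch, not a setting prescribed by the original code; the data paths in particular are assumptions.

train(data_path='../../data/QA_data/MetaQA/qa_train_1hop.txt',  # assumed layout
      neg_batch_size=128,
      batch_size=128,
      shuffle=True,
      num_workers=4,
      nb_epochs=90,
      embedding_dim=256,
      hidden_dim=200,
      relation_dim=200,
      gpu='cuda:0',
      use_cuda=True,
      patience=5,
      freeze=True,
      validate_every=5,
      hops='1',        # a value containing 'half' selects the half KG
      lr=0.0005,
      entdrop=0.0,
      reldrop=0.0,
      scoredrop=0.0,
      l3_reg=0.0,
      model_name='ComplEx',
      decay=1.0,
      ls=0.0,
      load_from='',
      outfile='best_model',
      do_batch_norm=True,
      valid_data_path='../../data/QA_data/MetaQA/qa_dev_1hop.txt')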
Example #4
def perform_experiment(data_path,
                       mode,
                       neg_batch_size,
                       batch_size,
                       shuffle,
                       num_workers,
                       nb_epochs,
                       embedding_dim,
                       hidden_dim,
                       relation_dim,
                       gpu,
                       use_cuda,
                       patience,
                       freeze,
                       validate_every,
                       hops,
                       lr,
                       entdrop,
                       reldrop,
                       scoredrop,
                       l3_reg,
                       model_name,
                       decay,
                       ls,
                       load_from,
                       outfile,
                       do_batch_norm,
                       que_embedding_model,
                       valid_data_path=None,
                       test_data_path=None):
    webqsp_checkpoint_folder = f"../../checkpoints/WebQSP/{model_name}_{que_embedding_model}_{outfile}/"
    if not os.path.exists(webqsp_checkpoint_folder):
        os.makedirs(webqsp_checkpoint_folder)

    print('Loading entities and relations')
    kg_type = 'full'
    if 'half' in hops:
        kg_type = 'half'

    checkpoint_file = f"../../pretrained_models/embeddings/{model_name}_fbwq_{kg_type}/checkpoint_best.pt"

    print('Loading kg embeddings from', checkpoint_file)
    kge_checkpoint = load_checkpoint(checkpoint_file)
    kge_model = KgeModel.create_from(kge_checkpoint)
    kge_model.eval()
    e = getEntityEmbeddings(model_name, kge_model, hops)

    print('Loaded entities and relations')

    entity2idx, idx2entity, embedding_matrix = prepare_embeddings(e)

    # word2ix,idx2word, max_len = get_vocab(data)
    # hops = str(num_hops)
    device = torch.device(gpu if use_cuda else "cpu")
    model = RelationExtractor(embedding_dim=embedding_dim,
                              num_entities=len(idx2entity),
                              relation_dim=relation_dim,
                              pretrained_embeddings=embedding_matrix,
                              freeze=freeze,
                              device=device,
                              entdrop=entdrop,
                              reldrop=reldrop,
                              scoredrop=scoredrop,
                              l3_reg=l3_reg,
                              model=model_name,
                              que_embedding_model=que_embedding_model,
                              ls=ls,
                              do_batch_norm=do_batch_norm)

    # time.sleep(10)
    if mode == 'train':
        data = process_text_file(data_path)
        dataset = DatasetWebQSP(data, e, entity2idx, que_embedding_model,
                                model_name)

        # if model_name=="ComplEx":
        #     data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        # else:
        #     data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=custom_collate_fn)

        data_loader = DataLoader(dataset,
                                 batch_size=batch_size,
                                 shuffle=shuffle,
                                 num_workers=num_workers)

        if load_from != '':
            # model.load_state_dict(torch.load("checkpoints/roberta_finetune/" + load_from + ".pt"))
            fname = f"checkpoints/{que_embedding_model}_finetune/{load_from}.pt"
            model.load_state_dict(
                torch.load(fname, map_location=lambda storage, loc: storage))
        model.to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = ExponentialLR(optimizer, decay)
        optimizer.zero_grad()
        best_score = -float("inf")
        best_model = model.state_dict()
        no_update = 0
        for epoch in range(nb_epochs):
            phases = []
            for i in range(validate_every):
                phases.append('train')
            phases.append('valid')
            for phase in phases:
                if phase == 'train':
                    model.train()
                    # model.apply(set_bn_eval)
                    loader = tqdm(data_loader,
                                  total=len(data_loader),
                                  unit="batches")
                    running_loss = 0
                    for i_batch, a in enumerate(loader):
                        model.zero_grad()
                        question_tokenized = a[0].to(device)
                        attention_mask = a[1].to(device)
                        positive_head = a[2].to(device)
                        positive_tail = a[3].to(device)
                        loss = model(question_tokenized=question_tokenized,
                                     attention_mask=attention_mask,
                                     p_head=positive_head,
                                     p_tail=positive_tail)
                        loss.backward()
                        optimizer.step()
                        running_loss += loss.item()
                        loader.set_postfix(Loss=running_loss /
                                           ((i_batch + 1) * batch_size),
                                           Epoch=epoch)
                        loader.set_description('{}/{}'.format(
                            epoch, nb_epochs))
                        loader.update()

                    scheduler.step()

                elif phase == 'valid':
                    model.eval()
                    eps = 0.0001
                    answers, score = test(model=model,
                                          data_path=valid_data_path,
                                          entity2idx=entity2idx,
                                          dataloader=dataset,
                                          device=device,
                                          model_name=model_name,
                                          return_hits_at_k=False)
                    if score > best_score + eps:
                        best_score = score
                        no_update = 0
                        best_model = model.state_dict()
                        print(
                            hops +
                            " hop Validation accuracy (no relation scoring) increased from previous epoch",
                            score)
                        writeToFile(
                            answers,
                            f'results/{model_name}_{que_embedding_model}_{outfile}.txt'
                        )
                        torch.save(
                            best_model,
                            get_chkpt_path(model_name, que_embedding_model,
                                           outfile))
                    elif (score < best_score + eps) and (no_update < patience):
                        no_update += 1
                        print(
                            "Validation accuracy decreases to %f from %f, %d more epoch to check"
                            % (score, best_score, patience - no_update))
                    elif no_update == patience:
                        print(
                            "Model has exceed patience. Saving best model and exiting"
                        )
                        torch.save(
                            best_model,
                            get_chkpt_path(model_name, que_embedding_model,
                                           outfile))
                        exit(0)
                    if epoch == nb_epochs - 1:
                        print(
                            "Final Epoch has reached. Stoping and saving model."
                        )
                        torch.save(
                            best_model,
                            get_chkpt_path(model_name, que_embedding_model,
                                           outfile))
                        exit()
                    # torch.save(model.state_dict(), "checkpoints/roberta_finetune/"+str(epoch)+".pt")
                    # torch.save(model.state_dict(), "checkpoints/roberta_finetune/x.pt")

    elif mode == 'test':
        data = process_text_file(test_data_path)
        dataset = DatasetWebQSP(data, e, entity2idx, que_embedding_model,
                                model_name)
        model_chkpt_file_path = get_chkpt_path(model_name, que_embedding_model,
                                               outfile)
        model.load_state_dict(
            torch.load(model_chkpt_file_path,
                       map_location=lambda storage, loc: storage))
        model.to(device)
        for parameter in model.parameters():
            parameter.requires_grad = False
        model.eval()
        answers, accuracy, hits_at_1, hits_at_5, hits_at_10 = test(
            model=model,
            data_path=test_data_path,
            entity2idx=entity2idx,
            dataloader=dataset,
            device=device,
            model_name=model_name,
            return_hits_at_k=True)

        d = {
            'KG-Model': model_name,
            'KG-Type': kg_type,
            'Que-Embedding-Model': que_embedding_model,
            'Accuracy': [accuracy],
            'Hits@1': [hits_at_1],
            'Hits@5': [hits_at_5],
            'Hits@10': [hits_at_10]
        }
        df = pd.DataFrame(data=d)
        df.to_csv(f"final_results.csv", mode='a', index=False, header=False)