Example 1
def _run_train_job(sicnk, device=None):
    """Runs a training job and returns the trace entry of its best validation result.

    Also takes care of appropriate tracing.

    """

    search_job, train_job_index, train_job_config, train_job_count, trace_keys = sicnk

    try:
        # load the job
        if device is not None:
            train_job_config.set("job.device", device)
        search_job.config.log(
            "Starting training job {} ({}/{}) on device {}...".format(
                train_job_config.folder,
                train_job_index + 1,
                train_job_count,
                train_job_config.get("job.device"),
            ))
        checkpoint_file = get_checkpoint_file(train_job_config)
        if checkpoint_file is not None:
            checkpoint = load_checkpoint(checkpoint_file,
                                         train_job_config.get("job.device"))
            job = Job.create_from(
                checkpoint=checkpoint,
                new_config=train_job_config,
                dataset=search_job.dataset,
                parent_job=search_job,
            )
        else:
            job = Job.create(
                config=train_job_config,
                dataset=search_job.dataset,
                parent_job=search_job,
            )

        # process the trace entries so far (in case of a resumed job)
        metric_name = search_job.config.get("valid.metric")
        valid_trace = []

        def copy_to_search_trace(job, trace_entry=None):
            if trace_entry is None:
                trace_entry = job.valid_trace[-1]
            trace_entry = copy.deepcopy(trace_entry)
            for key in trace_keys:
                # Process deprecated options to some extent. Support key renames, but
                # not value renames.
                actual_key = {key: None}
                _process_deprecated_options(actual_key)
                if len(actual_key) > 1:
                    raise KeyError(
                        f"{key} is deprecated but cannot be handled automatically"
                    )
                actual_key = next(iter(actual_key.keys()))
                value = train_job_config.get(actual_key)
                trace_entry[key] = value

            trace_entry["folder"] = os.path.split(train_job_config.folder)[1]
            metric_value = Trace.get_metric(trace_entry, metric_name)
            trace_entry["metric_name"] = metric_name
            trace_entry["metric_value"] = metric_value
            trace_entry["parent_job_id"] = search_job.job_id
            search_job.config.trace(**trace_entry)
            valid_trace.append(trace_entry)

        for trace_entry in job.valid_trace:
            copy_to_search_trace(None, trace_entry)

        # run the job (adding new trace entries as we go)
        # TODO make this less hacky (easier once integrated into SearchJob)
        from kge.job import ManualSearchJob

        if not isinstance(
                search_job,
                ManualSearchJob) or search_job.config.get("manual_search.run"):
            job.post_valid_hooks.append(copy_to_search_trace)
            job.run()
        else:
            search_job.config.log(
                "Skipping running of training job as requested by user.")
            return (train_job_index, None, None)

        # analyze the result
        search_job.config.log("Best result in this training job:")
        best = None
        best_metric = None
        for trace_entry in valid_trace:
            metric = trace_entry["metric_value"]
            if not best or Metric(search_job).better(metric, best_metric):
                best = trace_entry
                best_metric = metric

        # record the best result of this job
        best["child_job_id"] = best["job_id"]
        for k in ["job", "job_id", "type", "parent_job_id", "scope", "event"]:
            if k in best:
                del best[k]
        search_job.trace(
            event="search_completed",
            echo=True,
            echo_prefix="  ",
            log=True,
            scope="train",
            **best,
        )

        # force releasing the GPU memory of the job to avoid memory leakage
        del job
        gc.collect()

        return (train_job_index, best, best_metric)
    except BaseException as e:
        search_job.config.log("Trial {:05d} failed: {}".format(
            train_job_index, repr(e)))
        if search_job.on_error == "continue":
            return (train_job_index, None, None)
        else:
            search_job.config.log(
                "Aborting search due to failure of trial {:05d}".format(
                    train_job_index))
            raise e
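For orientation, here is a minimal dispatch sketch; it is an assumption, not part of the snippet above. It shows how a search job might pack the `sicnk` tuples and fan `_run_train_job` out over a device pool; the helper name `run_all_train_jobs` and the pool setup are hypothetical.

from concurrent.futures import ProcessPoolExecutor

def run_all_train_jobs(search_job, train_job_configs, trace_keys, devices):
    # pack the (search_job, index, config, count, trace_keys) tuples that
    # _run_train_job expects
    sicnks = [
        (search_job, i, cfg, len(train_job_configs), trace_keys)
        for i, cfg in enumerate(train_job_configs)
    ]
    # round-robin the jobs over the available devices
    with ProcessPoolExecutor(max_workers=len(devices)) as pool:
        futures = [
            pool.submit(_run_train_job, sicnk, devices[i % len(devices)])
            for i, sicnk in enumerate(sicnks)
        ]
        return [f.result() for f in futures]

Note that search_job must be picklable for a process pool; a thread pool or a plain sequential loop avoids that constraint.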
Example 2
def main():
    args = parse_args()

    # Load model checkpoint and data
    checkpoint = load_checkpoint(args.model_checkpoint, device="cpu")
    model_pt = kge.model.KgeModel.create_from(checkpoint)
    print("Loaded model from", args.model_checkpoint)

    dataset = model_pt.dataset

    # Load all data
    train_spo, valid_spo, test_spo = [
        dataset.split(split) for split in ("train", "valid", "test")
    ]
    all_spo = torch.cat((train_spo, valid_spo, test_spo), dim=0).long()

    # Load relation ID to string mapping
    relation_ids = dataset.relation_ids()
    metric_names = ("mrr", "hits@10")
    metrics_all = defaultdict(lambda: defaultdict(list))
    dfs = []

    # Evaluate model and baseline separately for each relation type in the test set
    for rid in tqdm(torch.unique(test_spo[:, 1]), desc="Relation"):
        rid = rid.item()

        # Get all test triples with this relation
        test_filt = test_spo[test_spo[:, 1] == rid]

        for direction in ["s", "o"]:  # (?, r, t) and (h, r, ?)
            metrics_mean = defaultdict(dict)

            for modelname, score_fn in zip(
                ["Model", "Baseline"], [score_with_model, score_by_frequency]):

                # score test triples and evaluate rankings
                scores = score_fn(model_pt, test_filt, direction=direction)
                model_metrics = evaluate_rankings(scores,
                                                  test_filt,
                                                  all_spo,
                                                  direction=direction)

                for metric_name, metric in zip(metric_names, model_metrics):
                    metrics_mean[modelname][metric_name] = np.mean(metric)
                    metrics_all[modelname][metric_name].extend(metric)

            for metric_name in metric_names:
                model_metric = metrics_mean["Model"][metric_name]
                baseline_metric = metrics_mean["Baseline"][metric_name]
                diff = model_metric - baseline_metric

                line = dict(
                    relation=relation_ids[rid],
                    metric=metric_name,
                    direction=direction,
                    count=len(test_filt),
                    diff=diff,
                    model=model_metric,
                    baseline=baseline_metric,
                )

                if args.csv is not None:
                    dfs.append(
                        pd.DataFrame.from_dict(line,
                                               orient="index").transpose())

    if args.csv is not None:
        df = pd.concat(dfs)
        df.to_csv(args.csv, index=False)
        print("Saved results to", args.csv)

    for modelname in metrics_all:
        for metric, scores in metrics_all[modelname].items():
            print(modelname, metric, np.mean(scores))
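`score_with_model` and `score_by_frequency` are referenced above but not defined here. A minimal sketch of the frequency baseline, assuming the signature implied by the call site and that candidates are scored purely by how often they appear in the training split:

import torch

def score_by_frequency(model, triples, direction="o"):
    # hypothetical baseline: every query gets the same score row, namely the
    # frequency of each candidate entity in the training split
    col = 0 if direction == "s" else 2
    train = model.dataset.split("train").long()
    counts = torch.bincount(train[:, col],
                            minlength=model.dataset.num_entities())
    return counts.float().unsqueeze(0).expand(len(triples), -1)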
Example 3
def main():
    args = parse_args()

    # Load first model, get dataset
    # Assumes all models trained on same data
    checkpoint = load_checkpoint(args.model_files[0], device="cpu")
    model = kge.model.KgeModel.create_from(checkpoint)
    dataset = model.dataset

    splits = ("valid", "test")
    valid_spo, test_spo = [dataset.split(split).long() for split in splits]

    if args.negative in ("uniform", "frequency"):
        valid_neg_spo, test_neg_spo = [
            generate_neg_spo(dataset, split, negative_type=args.negative)
            for split in splits
        ]
    else:
        valid_neg_spo, test_neg_spo = load_neg_spo(dataset, size=args.size)
        print(
            f"Loaded {len(valid_neg_spo)} valid negatives",
            f"and {len(test_neg_spo)} test negatives",
        )

    valid_spo_all = torch.cat((valid_spo, valid_neg_spo))
    test_spo_all = torch.cat((test_spo, test_neg_spo))

    metrics = []
    dfs = []

    for model_file in args.model_files:
        if os.path.exists(model_file):
            checkpoint = load_checkpoint(model_file, device="cpu")
            model = kge.model.KgeModel.create_from(checkpoint)

            # Score negative and positive validation triples
            X_valid, y_valid = get_X_y(model, valid_spo, valid_neg_spo)
            X_test, y_test = get_X_y(model, test_spo, test_neg_spo)

            valid_relations = valid_spo_all[:, 1].unique()
            test_relations = test_spo_all[:, 1].unique()

            y_pred_valid = torch.zeros(y_valid.shape,
                                       dtype=torch.long,
                                       device="cpu")
            y_pred_test = torch.zeros(y_test.shape,
                                      dtype=torch.long,
                                      device="cpu")

            ############################################################################
            # begin credits:
            # https://github.com/uma-pi1/kge/blob/triple_classification/kge/job/triple_classification.py#L302
            ############################################################################
            REL_KEY = -1
            thresholds = {
                r: -float("inf")
                for r in range(dataset.num_relations())
            }
            thresholds[REL_KEY] = -float("inf")

            for r in valid_relations:  # set a threshold for each relation
                current_rel = valid_spo_all[:, 1] == r
                threshold = get_threshold(X_valid[current_rel],
                                          y_valid[current_rel])
                thresholds[r.item()] = threshold

                predictions = X_valid[current_rel] >= threshold
                y_pred_valid[current_rel] = predictions.view(-1).long()

            # also set a global threshold for relations unseen in valid set
            thresholds[REL_KEY] = get_threshold(X_valid, y_valid)

            for r in test_relations:  # get predictions based on validation thresholds
                key = r.item() if r.item() in thresholds else REL_KEY
                threshold = thresholds[key]

                current_rel = test_spo_all[:, 1] == r
                predictions = X_test[current_rel] >= threshold

                y_pred_test[current_rel] = predictions.view(-1).long()
            ############################################################################
            #                                end credits                               #
            ############################################################################

            y_test = y_test.numpy()
            y_pred_test = y_pred_test.numpy()

            line = dict(
                valid_accuracy=accuracy_score(y_valid, y_pred_valid),
                valid_f1=f1_score(y_valid, y_pred_valid),
                test_accuracy=accuracy_score(y_test, y_pred_test),
                test_f1=f1_score(y_test, y_pred_test),
                model_file=model_file,
            )

            metrics.append(line)

            if args.csv is not None:
                dfs.append(
                    pd.DataFrame.from_dict(line, orient="index").transpose())

    if args.csv is not None:
        df = pd.concat(dfs)
        df.to_csv(args.csv, index=False)
        print("Saved results to", args.csv)

    for metric in metrics:
        for key, val in metric.items():
            print(f"{key}: {val}")
        print()
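`get_threshold` is used above without being shown. A minimal sketch under the standard triple-classification protocol, where the score threshold maximizing validation accuracy is selected; the exact candidate set and tie-breaking are assumptions:

import torch

def get_threshold(scores, labels):
    # hypothetical implementation: try each observed score as a threshold and
    # keep the one with the highest accuracy
    scores, labels = scores.view(-1), labels.view(-1)
    best_acc, best_threshold = -1.0, -float("inf")
    for t in scores.unique().tolist():
        acc = ((scores >= t).long() == labels).float().mean().item()
        if acc > best_acc:
            best_acc, best_threshold = acc, t
    return best_threshold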
Example 4
from util.eval import Evaluator
import json
import torch
from kge.model import KgeModel
from kge.util.io import load_checkpoint
import numpy as np
# Link prediction performance of RESCAL, ComplEx, ConvE, DistMult and TransE on FB15K-237 (out-of-vocabulary entities are removed)
models = ['rescal', 'complex', 'conve', 'distmult', 'transe']

for m in models:
    if m == 'conex':
        """ """
        raise NotImplementedError()
    else:
        # 1. Load pretrained model via LibKGE
        checkpoint = load_checkpoint(
            f'pretrained_models/FB15K-237/fb15k-237-{m}.pt')
        model = KgeModel.create_from(checkpoint)

        # 2. Create mappings.
        # 2.1 Entity index mapping.
        entity_idxs = {
            e: e_idx for e_idx, e in enumerate(model.dataset.entity_ids())
        }
        # 2.2 Relation index mapping.
        relation_idxs = {
            r: r_idx for r_idx, r in enumerate(model.dataset.relation_ids())
        }
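With these mappings in place, one can, for example, look up indices and score a triple through LibKGE's `score_spo`; the sketch below just takes the first identifiers of the dataset as stand-ins for a real query:

import torch

head, tail = model.dataset.entity_ids()[0], model.dataset.entity_ids()[1]
rel = model.dataset.relation_ids()[0]
s = torch.tensor([entity_idxs[head]])
p = torch.tensor([relation_idxs[rel]])
o = torch.tensor([entity_idxs[tail]])
score = model.score_spo(s, p, o)  # higher score = more plausible triple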
Example 5
from util.eval import Evaluator
import json
import torch
from kge.model import KgeModel
from kge.util.io import load_checkpoint
import numpy as np
# Link prediction performance of RESCAL, ComplEx, ConvE, DistMult and TransE on WN18RR* (out-of-vocabulary entities are removed)
models = ['rescal', 'complex', 'conve', 'distmult', 'transe']

for m in models:
    if m == 'conex':
        """ """
        raise NotImplementedError()
    else:
        # 1. Load pretrained model via LibKGE
        checkpoint = load_checkpoint(f'pretrained_models/WN18RR/wnrr-{m}.pt')
        model = KgeModel.create_from(checkpoint)

        # 2. Create mappings.
        # 2.1 Entity index mapping.
        entity_idxs = {
            e: e_idx for e_idx, e in enumerate(model.dataset.entity_ids())
        }
        # 2.2 Relation index mapping.
        relation_idxs = {
            r: r_idx for r_idx, r in enumerate(model.dataset.relation_ids())
        }
Example 6
    def run(self):
        torch_device = self.config.get("job.device")
        if torch_device == "cuda":
            torch_device = "cuda:0"
        if torch_device != "cpu":
            torch.cuda.set_device(torch_device)
        # seeds need to be set in every process
        set_seeds(self.config, self.rank)

        os.environ["MASTER_ADDR"] = self.config.get("job.distributed.master_ip")
        os.environ["MASTER_PORT"] = self.config.get("job.distributed.master_port")
        min_rank = get_min_rank(self.config)
        print("before init", self.rank + min_rank)
        dist.init_process_group(
            backend="gloo",
            init_method="env://",
            world_size=self.num_total_workers + min_rank,
            rank=self.rank + min_rank,
            timeout=datetime.timedelta(hours=6),
        )
        worker_ranks = list(range(min_rank, self.num_total_workers+min_rank))
        worker_group = dist.new_group(worker_ranks, timeout=datetime.timedelta(hours=6))

        # create parameter server
        server = None
        if self.config.get("job.distributed.parameter_server") == "lapse":
            os.environ["DMLC_NUM_WORKER"] = "0"
            os.environ["DMLC_NUM_SERVER"] = str(self.num_total_workers)
            os.environ["DMLC_ROLE"] = "server"
            os.environ["DMLC_PS_ROOT_URI"] = self.config.get(
                "job.distributed.master_ip"
            )
            os.environ["DMLC_PS_ROOT_PORT"] = self.config.get(
                "job.distributed.lapse_port"
            )

            num_workers_per_server = 1
            lapse.setup(self.num_keys, num_workers_per_server)
            server = lapse.Server(self.num_keys, self.embedding_dim + self.optimizer_dim)
        elif self.config.get("job.distributed.parameter_server") == "shared":
            server = self.parameters

        # create train-worker config, dataset and folder
        device_pool: list = self.config.get("job.device_pool")
        if len(device_pool) == 0:
            device_pool.append(self.config.get("job.device"))
        worker_id = self.rank
        config = deepcopy(self.config)
        config.set("job.device", device_pool[worker_id % len(device_pool)])
        config.folder = os.path.join(self.config.folder, f"worker-{self.rank}")
        config.init_folder()
        dataset = deepcopy(self.dataset)

        parameter_client = KgeParameterClient.create(
            client_type=self.config.get("job.distributed.parameter_server"),
            server_id=0,
            client_id=worker_id + min_rank,
            embedding_dim=self.embedding_dim + self.optimizer_dim,
            server=server,
            num_keys=self.num_keys,
            num_meta_keys=self.num_meta_keys,
            worker_group=worker_group,
        )
        # don't re-initialize the model after loading checkpoint
        init_for_load_only = self.checkpoint_name is not None
        job = Job.create(
            config=config,
            dataset=dataset,
            parameter_client=parameter_client,
            init_for_load_only=init_for_load_only,
        )
        if self.checkpoint_name is not None:
            checkpoint = load_checkpoint(self.checkpoint_name)
            job._load(checkpoint)
            job.load_distributed(checkpoint_name=self.checkpoint_name)

        job.run()

        # all done, clean up
        print("shut down everything")
        parameter_client.barrier()
        if hasattr(job, "work_scheduler_client"):
            job.work_scheduler_client.shutdown()
        parameter_client.shutdown()
        # delete all occurrences of the parameter client to properly shutdown lapse
        # del job
        del job.parameter_client
        del job.model.get_s_embedder().parameter_client
        del job.model.get_p_embedder().parameter_client
        del job.model
        if hasattr(job, "optimizer"):
            del job.optimizer
        del parameter_client
        gc.collect()  # make sure lapse-worker destructor is called
        # shutdown server
        if server is not None and not isinstance(server, torch.Tensor):
            server.shutdown()
        if self.result_pipe is not None:
            if hasattr(job, "valid_trace"):
                # if we only validate from a checkpoint, there is no valid trace
                self.result_pipe.send(job.valid_trace)
            else:
                self.result_pipe.send(None)
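For context, a sketch of how such a worker's `run()` is typically driven from the main process with `torch.multiprocessing`; the `worker_factory` callable is an assumption, since the surrounding class is not shown here:

import torch.multiprocessing as mp

def launch_workers(worker_factory, num_workers):
    # hypothetical launcher: one process per worker rank, each executing run()
    processes = []
    for rank in range(num_workers):
        worker = worker_factory(rank)  # assumed to return an object with .run()
        proc = mp.Process(target=worker.run)
        proc.start()
        processes.append(proc)
    for proc in processes:
        proc.join()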
Example 7
def main():
    # default config
    config = Config()

    # now parse the arguments
    parser = create_parser(config)
    args, unknown_args = parser.parse_known_args()

    # If there were unknown args, add them to the parser and reparse. The correctness
    # of these arguments will be checked later.
    if len(unknown_args) > 0:
        parser = create_parser(
            config, filter(lambda a: a.startswith("--"), unknown_args)
        )
        args = parser.parse_args()

    # process meta-commands
    process_meta_command(args, "create", {"command": "start", "run": False})
    process_meta_command(args, "eval", {"command": "resume", "job.type": "eval"})
    process_meta_command(
        args, "test", {"command": "resume", "job.type": "eval", "eval.split": "test"}
    )
    process_meta_command(
        args, "valid", {"command": "resume", "job.type": "eval", "eval.split": "valid"}
    )
    # dump command
    if args.command == "dump":
        dump(args)
        exit()

    # package command
    if args.command == "package":
        package_model(args)
        exit()

    # start command
    if args.command == "start":
        # use toy config file if no config given
        if args.config is None:
            args.config = os.path.join(
                kge_base_dir(), "examples", "toy-complex-train.yaml"
            )
            print(
                "WARNING: No configuration specified; using " + args.config,
                file=sys.stderr,
            )

        if not vars(args)["console.quiet"]:
            print("Loading configuration {}...".format(args.config))
        config.load(args.config)

    # resume command
    if args.command == "resume":
        if os.path.isdir(args.config) and os.path.isfile(args.config + "/config.yaml"):
            args.config += "/config.yaml"
        if not vars(args)["console.quiet"]:
            print("Resuming from configuration {}...".format(args.config))
        config.load(args.config)
        config.folder = os.path.dirname(args.config)
        if not config.folder:
            config.folder = "."
        if not os.path.exists(config.folder):
            raise ValueError(
                "{} is not a valid config file for resuming".format(args.config)
            )

    # overwrite configuration with command line arguments
    for key, value in vars(args).items():
        if key in [
            "command",
            "config",
            "run",
            "folder",
            "checkpoint",
            "abort_when_cache_outdated",
        ]:
            continue
        if value is not None:
            if key == "search.device_pool":
                value = "".join(value).split(",")
            try:
                if isinstance(config.get(key), bool):
                    value = argparse_bool_type(value)
            except KeyError:
                pass
            config.set(key, value)
            if key == "model":
                config._import(value)

    # initialize output folder
    if args.command == "start":
        if args.folder is None:  # means: set default
            config_name = os.path.splitext(os.path.basename(args.config))[0]
            config.folder = os.path.join(
                kge_base_dir(),
                "local",
                "experiments",
                datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + config_name,
            )
        else:
            config.folder = args.folder

    # catch errors to log them
    try:
        if args.command == "start" and not config.init_folder():
            raise ValueError("output folder {} exists already".format(config.folder))
        config.log("Using folder: {}".format(config.folder))

        # determine checkpoint to resume (if any)
        if hasattr(args, "checkpoint"):
            checkpoint_file = get_checkpoint_file(config, args.checkpoint)

        # disable processing of outdated cached dataset files globally
        Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated

        # log configuration
        config.log("Configuration:")
        config.log(yaml.dump(config.options), prefix="  ")
        config.log("git commit: {}".format(get_git_revision_short_hash()), prefix="  ")

        # set random seeds
        def get_seed(what):
            seed = config.get(f"random_seed.{what}")
            if seed < 0 and config.get("random_seed.default") >= 0:
                import hashlib

                # we add an md5 hash to the default seed so that different PRNGs get a
                # different seed
                seed = (
                    config.get("random_seed.default")
                    + int(hashlib.md5(what.encode()).hexdigest(), 16)
                ) % 0xFFFF  # stay within 16 bits

            return seed

        if get_seed("python") > -1:
            import random

            random.seed(get_seed("python"))
        if get_seed("torch") > -1:
            import torch

            torch.manual_seed(get_seed("torch"))
        if get_seed("numpy") > -1:
            import numpy.random

            numpy.random.seed(get_seed("numpy"))
        if get_seed("numba") > -1:
            import numpy as np, numba

            @numba.njit
            def seed_numba(seed):
                np.random.seed(seed)

            seed_numba(get_seed("numba"))

        # let's go
        if args.command == "start" and not args.run:
            config.log("Job created successfully.")
        else:
            # load data
            dataset = Dataset.create(config)

            # let's go
            if args.command == "resume":
                if checkpoint_file is not None:
                    checkpoint = load_checkpoint(
                        checkpoint_file, config.get("job.device")
                    )
                    job = Job.create_from(
                        checkpoint, new_config=config, dataset=dataset
                    )
                else:
                    job = Job.create(config, dataset)
                    job.config.log(
                        "No checkpoint found or specified, starting from scratch..."
                    )
            else:
                job = Job.create(config, dataset)
            job.run()
    except BaseException:
        tb = traceback.format_exc()
        config.log(tb, echo=False)
        raise
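The seed-derivation rule in `get_seed` is easy to check in isolation: the default seed is offset by an md5 hash of the PRNG name, so each PRNG receives a distinct but reproducible seed. A standalone restatement:

import hashlib

def derive_seed(default_seed, what):
    # same rule as get_seed above, without the config lookups
    return (default_seed
            + int(hashlib.md5(what.encode()).hexdigest(), 16)) % 0xFFFF

print(derive_seed(42, "torch"), derive_seed(42, "numpy"))  # distinct seeds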
Example 8
def train(data_path,
          neg_batch_size,
          batch_size,
          shuffle,
          num_workers,
          nb_epochs,
          embedding_dim,
          hidden_dim,
          relation_dim,
          gpu,
          use_cuda,
          patience,
          freeze,
          validate_every,
          hops,
          lr,
          entdrop,
          reldrop,
          scoredrop,
          l3_reg,
          model_name,
          decay,
          ls,
          load_from,
          outfile,
          do_batch_norm,
          valid_data_path=None):
    print('Loading entities and relations')
    kg_type = 'full'
    if 'half' in hops:
        kg_type = 'half'
    checkpoint_file = f'../../pretrained_models/embeddings/ComplEx_fbwq_{kg_type}/checkpoint_best.pt'
    print('Loading kg embeddings from', checkpoint_file)
    kge_checkpoint = load_checkpoint(checkpoint_file)
    kge_model = KgeModel.create_from(kge_checkpoint)
    kge_model.eval()
    e = getEntityEmbeddings(kge_model, hops)

    print('Loaded entities and relations')

    entity2idx, idx2entity, embedding_matrix = prepare_embeddings(e)
    data = process_text_file(data_path, split=False)
    print('Train file processed, making dataloader')
    # word2ix,idx2word, max_len = get_vocab(data)
    # hops = str(num_hops)
    device = torch.device(gpu if use_cuda else "cpu")
    dataset = DatasetMetaQA(data, e, entity2idx)
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=shuffle,  # honor the shuffle argument
                             num_workers=num_workers)
    print('Creating model...')
    model = RelationExtractor(embedding_dim=embedding_dim,
                              num_entities=len(idx2entity),
                              relation_dim=relation_dim,
                              pretrained_embeddings=embedding_matrix,
                              freeze=freeze,
                              device=device,
                              entdrop=entdrop,
                              reldrop=reldrop,
                              scoredrop=scoredrop,
                              l3_reg=l3_reg,
                              model=model_name,
                              ls=ls,
                              do_batch_norm=do_batch_norm)
    print('Model created!')
    if load_from != '':
        # model.load_state_dict(torch.load("checkpoints/roberta_finetune/" + load_from + ".pt"))
        fname = "checkpoints/roberta_finetune/" + load_from + ".pt"
        model.load_state_dict(
            torch.load(fname, map_location=lambda storage, loc: storage))
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = ExponentialLR(optimizer, decay)
    optimizer.zero_grad()
    best_score = -float("inf")
    best_model = model.state_dict()
    no_update = 0
    # time.sleep(10)
    for epoch in range(nb_epochs):
        phases = []
        for i in range(validate_every):
            phases.append('train')
        phases.append('valid')
        for phase in phases:
            if phase == 'train':
                model.train()
                # model.apply(set_bn_eval)
                loader = tqdm(data_loader,
                              total=len(data_loader),
                              unit="batches")
                running_loss = 0
                for i_batch, a in enumerate(loader):
                    model.zero_grad()
                    question_tokenized = a[0].to(device)
                    attention_mask = a[1].to(device)
                    positive_head = a[2].to(device)
                    positive_tail = a[3].to(device)
                    loss = model(question_tokenized=question_tokenized,
                                 attention_mask=attention_mask,
                                 p_head=positive_head,
                                 p_tail=positive_tail)
                    loss.backward()
                    optimizer.step()
                    running_loss += loss.item()
                    loader.set_postfix(Loss=running_loss /
                                       ((i_batch + 1) * batch_size),
                                       Epoch=epoch)
                    loader.set_description('{}/{}'.format(epoch, nb_epochs))
                    loader.update()

                scheduler.step()

            elif phase == 'valid':
                model.eval()
                eps = 0.0001
                answers, score = validate_v2(model=model,
                                             data_path=valid_data_path,
                                             entity2idx=entity2idx,
                                             train_dataloader=dataset,
                                             device=device,
                                             model_name=model_name)
                if score > best_score + eps:
                    best_score = score
                    no_update = 0
                    best_model = model.state_dict()
                    print(
                        hops +
                        " hop Validation accuracy (no relation scoring) increased from previous epoch",
                        score)
                    # writeToFile(answers, 'results_' + model_name + '_' + hops + '.txt')
                    # torch.save(best_model, "checkpoints/roberta_finetune/best_score_model.pt")
                    # torch.save(best_model, "checkpoints/roberta_finetune/" + outfile + ".pt")
                elif (score < best_score + eps) and (no_update < patience):
                    no_update += 1
                    print(
                        "Validation accuracy decreases to %f from %f, %d more epoch to check"
                        % (score, best_score, patience - no_update))
                elif no_update == patience:
                    print(
                        "Model has exceed patience. Saving best model and exiting"
                    )
                    # torch.save(best_model, "checkpoints/roberta_finetune/best_score_model.pt")
                    # torch.save(best_model, "checkpoints/roberta_finetune/" + outfile + ".pt")
                    exit()
                if epoch == nb_epochs - 1:
                    print("Final Epoch has reached. Stoping and saving model.")
                    # torch.save(best_model, "checkpoints/roberta_finetune/best_score_model.pt")
                    # torch.save(best_model, "checkpoints/roberta_finetune/" + outfile + ".pt")
                    exit()
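`prepare_embeddings` is called above but not defined here. A minimal sketch, assuming `e` maps entity names to embedding tensors of equal size:

import torch

def prepare_embeddings(embedding_dict):
    # hypothetical implementation: assign consecutive indices in iteration
    # order and stack the vectors into a single matrix
    entity2idx, idx2entity, vectors = {}, {}, []
    for idx, (entity, vector) in enumerate(embedding_dict.items()):
        entity2idx[entity] = idx
        idx2entity[idx] = entity
        vectors.append(vector)
    return entity2idx, idx2entity, torch.stack(vectors)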
Example 9
def perform_experiment(data_path,
                       mode,
                       neg_batch_size,
                       batch_size,
                       shuffle,
                       num_workers,
                       nb_epochs,
                       embedding_dim,
                       hidden_dim,
                       relation_dim,
                       gpu,
                       use_cuda,
                       patience,
                       freeze,
                       validate_every,
                       hops,
                       lr,
                       entdrop,
                       reldrop,
                       scoredrop,
                       l3_reg,
                       model_name,
                       decay,
                       ls,
                       load_from,
                       outfile,
                       do_batch_norm,
                       que_embedding_model,
                       valid_data_path=None,
                       test_data_path=None):
    webqsp_checkpoint_folder = f"../../checkpoints/WebQSP/{model_name}_{que_embedding_model}_{outfile}/"
    if not os.path.exists(webqsp_checkpoint_folder):
        os.makedirs(webqsp_checkpoint_folder)

    print('Loading entities and relations')
    kg_type = 'full'
    if 'half' in hops:
        kg_type = 'half'

    checkpoint_file = f"../../pretrained_models/embeddings/{model_name}_fbwq_{kg_type}/checkpoint_best.pt"

    print('Loading kg embeddings from', checkpoint_file)
    kge_checkpoint = load_checkpoint(checkpoint_file)
    kge_model = KgeModel.create_from(kge_checkpoint)
    kge_model.eval()
    e = getEntityEmbeddings(model_name, kge_model, hops)

    print('Loaded entities and relations')

    entity2idx, idx2entity, embedding_matrix = prepare_embeddings(e)

    # word2ix,idx2word, max_len = get_vocab(data)
    # hops = str(num_hops)
    device = torch.device(gpu if use_cuda else "cpu")
    model = RelationExtractor(embedding_dim=embedding_dim,
                              num_entities=len(idx2entity),
                              relation_dim=relation_dim,
                              pretrained_embeddings=embedding_matrix,
                              freeze=freeze,
                              device=device,
                              entdrop=entdrop,
                              reldrop=reldrop,
                              scoredrop=scoredrop,
                              l3_reg=l3_reg,
                              model=model_name,
                              que_embedding_model=que_embedding_model,
                              ls=ls,
                              do_batch_norm=do_batch_norm)

    # time.sleep(10)
    if mode == 'train':
        data = process_text_file(data_path)
        dataset = DatasetWebQSP(data, e, entity2idx, que_embedding_model,
                                model_name)

        # if model_name=="ComplEx":
        #     data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        # else:
        #     data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=custom_collate_fn)

        data_loader = DataLoader(dataset,
                                 batch_size=batch_size,
                                 shuffle=shuffle,  # honor the shuffle argument
                                 num_workers=num_workers)

        if load_from != '':
            # model.load_state_dict(torch.load("checkpoints/roberta_finetune/" + load_from + ".pt"))
            fname = f"checkpoints/{que_embedding_model}_finetune/{load_from}.pt"
            model.load_state_dict(
                torch.load(fname, map_location=lambda storage, loc: storage))
        model.to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = ExponentialLR(optimizer, decay)
        optimizer.zero_grad()
        best_score = -float("inf")
        best_model = model.state_dict()
        no_update = 0
        for epoch in range(nb_epochs):
            phases = []
            for i in range(validate_every):
                phases.append('train')
            phases.append('valid')
            for phase in phases:
                if phase == 'train':
                    model.train()
                    # model.apply(set_bn_eval)
                    loader = tqdm(data_loader,
                                  total=len(data_loader),
                                  unit="batches")
                    running_loss = 0
                    for i_batch, a in enumerate(loader):
                        model.zero_grad()
                        question_tokenized = a[0].to(device)
                        attention_mask = a[1].to(device)
                        positive_head = a[2].to(device)
                        positive_tail = a[3].to(device)
                        loss = model(question_tokenized=question_tokenized,
                                     attention_mask=attention_mask,
                                     p_head=positive_head,
                                     p_tail=positive_tail)
                        loss.backward()
                        optimizer.step()
                        running_loss += loss.item()
                        loader.set_postfix(Loss=running_loss /
                                           ((i_batch + 1) * batch_size),
                                           Epoch=epoch)
                        loader.set_description('{}/{}'.format(
                            epoch, nb_epochs))
                        loader.update()

                    scheduler.step()

                elif phase == 'valid':
                    model.eval()
                    eps = 0.0001
                    answers, score = test(model=model,
                                          data_path=valid_data_path,
                                          entity2idx=entity2idx,
                                          dataloader=dataset,
                                          device=device,
                                          model_name=model_name,
                                          return_hits_at_k=False)
                    if score > best_score + eps:
                        best_score = score
                        no_update = 0
                        best_model = model.state_dict()
                        print(
                            hops +
                            " hop Validation accuracy (no relation scoring) increased from previous epoch",
                            score)
                        writeToFile(
                            answers,
                            f'results/{model_name}_{que_embedding_model}_{outfile}.txt'
                        )
                        torch.save(
                            best_model,
                            get_chkpt_path(model_name, que_embedding_model,
                                           outfile))
                    elif (score < best_score + eps) and (no_update < patience):
                        no_update += 1
                        print(
                            "Validation accuracy decreases to %f from %f, %d more epoch to check"
                            % (score, best_score, patience - no_update))
                    elif no_update == patience:
                        print(
                            "Model has exceed patience. Saving best model and exiting"
                        )
                        torch.save(
                            best_model,
                            get_chkpt_path(model_name, que_embedding_model,
                                           outfile))
                        exit(0)
                    if epoch == nb_epochs - 1:
                        print(
                            "Final Epoch has reached. Stoping and saving model."
                        )
                        torch.save(
                            best_model,
                            get_chkpt_path(model_name, que_embedding_model,
                                           outfile))
                        exit()
                    # torch.save(model.state_dict(), "checkpoints/roberta_finetune/"+str(epoch)+".pt")
                    # torch.save(model.state_dict(), "checkpoints/roberta_finetune/x.pt")

    elif mode == 'test':
        data = process_text_file(test_data_path)
        dataset = DatasetWebQSP(data, e, entity2idx, que_embedding_model,
                                model_name)
        model_chkpt_file_path = get_chkpt_path(model_name, que_embedding_model,
                                               outfile)
        model.load_state_dict(
            torch.load(model_chkpt_file_path,
                       map_location=lambda storage, loc: storage))
        model.to(device)
        for parameter in model.parameters():
            parameter.requires_grad = False
        model.eval()
        answers, accuracy, hits_at_1, hits_at_5, hits_at_10 = test(
            model=model,
            data_path=test_data_path,
            entity2idx=entity2idx,
            dataloader=dataset,
            device=device,
            model_name=model_name,
            return_hits_at_k=True)

        d = {
            'KG-Model': model_name,
            'KG-Type': kg_type,
            'Que-Embedding-Model': que_embedding_model,
            'Accuracy': [accuracy],
            'Hits@1': [hits_at_1],
            'Hits@5': [hits_at_5],
            'Hits@10': [hits_at_10]
        }
        df = pd.DataFrame(data=d)
        df.to_csv(f"final_results.csv", mode='a', index=False, header=False)