Example #1
    def __init__(self):
        super(BlinkReader, self).__init__(filename="data/blink.pickle")

        self.model_loaded = False
        self.evaluator = None
        self.biencoder = None
        self.biencoder_params = None
        self.candidate_encoding = None
        self.faiss_indexer = None
        self.top_k = 1
        self.logger = utils.get_logger('output')
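
The fields above are placeholders for deferred loading: constructing the reader stays cheap, and the heavy BLINK components are attached later. A minimal sketch of such a setter (a hypothetical method, not from the source):

    def load_model(self, biencoder, biencoder_params, candidate_encoding,
                   faiss_indexer):
        # Hypothetical helper: populates the placeholders set in __init__
        self.biencoder = biencoder
        self.biencoder_params = biencoder_params
        self.candidate_encoding = candidate_encoding
        self.faiss_indexer = faiss_indexer
        self.model_loaded = True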
Example #2
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(output_path)

    logger.info("Loading candidate encoding from path: %s" %
                params["candidate_encoding"])
    candidate_encoding = torch.load(params["candidate_encoding"])
    vector_size = candidate_encoding.size(1)
    index_buffer = params["index_buffer"]
    if params["hnsw"]:
        logger.info("Using HNSW index in FAISS")
        index = DenseHNSWFlatIndexer(vector_size, index_buffer)
    else:
        logger.info("Using Flat index in FAISS")
        index = DenseFlatIndexer(vector_size, index_buffer)

    logger.info("Building index.")
    index.index_data(candidate_encoding.numpy())
    logger.info("Done indexing data.")

    if params.get("save_index", None):
        index.serialize(output_path)
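
A minimal driver for the snippet above (a sketch only; the paths and the index_buffer value are illustrative placeholders, not from the source):

params = {
    "output_path": "output/faiss_index",                   # illustrative path
    "candidate_encoding": "models/candidate_encoding.t7",  # illustrative path
    "index_buffer": 50000,
    "hnsw": False,          # True selects DenseHNSWFlatIndexer instead
    "save_index": True,
}
main(params)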
Example #3
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import argparse
import logging
import numpy
import os
import time
import torch

from blink.index.faiss_indexer import DenseFlatIndexer, DenseHNSWFlatIndexer
import blink.candidate_ranking.utils as utils

logger = utils.get_logger()


def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(output_path)

    logger.info("Loading candidate encoding from path: %s" %
                params["candidate_encoding"])
    candidate_encoding = torch.load(params["candidate_encoding"])
    vector_size = candidate_encoding.size(1)
    index_buffer = params["index_buffer"]
    if params["hnsw"]:
        logger.info("Using HNSW index in FAISS")
Example #4
def main(params):
    model_output_path = params["output_path"]
    if not os.path.exists(model_output_path):
        os.makedirs(model_output_path)
    pickle_src_path = params["pickle_src_path"]
    logger = utils.get_logger(params["output_path"])

    # Init model
    reranker = CrossEncoderRanker(params)
    tokenizer = reranker.tokenizer
    model = reranker.model
    device = reranker.device
    n_gpu = reranker.n_gpu

    if params["gradient_accumulation_steps"] < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                params["gradient_accumulation_steps"]
            )
        )
    # An effective batch size of `x`, when accumulating the gradient across `y` batches, is achieved with a per-step batch size of `z = x / y`
    # args.gradient_accumulation_steps = args.gradient_accumulation_steps // n_gpu
    params["train_batch_size"] = (
        params["train_batch_size"] // params["gradient_accumulation_steps"]
    )
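    # e.g., train_batch_size=64 with gradient_accumulation_steps=4 runs
    # per-step batches of 16; gradients are accumulated over 4 steps, so
    # the effective batch size stays 64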
    grad_acc_steps = params["gradient_accumulation_steps"]

    # Fix the random seeds
    seed = params["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if reranker.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    max_seq_length = params["max_seq_length"]
    context_length = params["max_context_length"]
    candidate_length = params["max_cand_length"]

    if params["only_evaluate"]:
        test_dataloader, n_test_skipped, data = get_data_loader("test", tokenizer, context_length, candidate_length,
                                                                max_seq_length, pickle_src_path, logger,
                                                                inject_ground_truth=params["inject_eval_ground_truth"],
                                                                shuffle=False, return_data=True,
                                                                custom_cand_set=params["custom_cand_set"])
        logger.info("Evaluating the model on the test set")
        results = evaluate(
            reranker,
            test_dataloader,
            device=device,
            logger=logger,
            context_length=context_length,
            silent=params["silent"],
            unfiltered_length=len(data["mention_data"]),
            mention_data=data,
            compute_macro_avg=True,
            store_failure_success=True
        )
        results_path = os.path.join(model_output_path, 'results.json')
        with open(results_path, 'w') as f:
            json.dump(results, f, indent=2)
            print(f"\nAnalysis saved at: {results_path}")
        exit()

    train_dataloader, _, train_data = get_data_loader('train', tokenizer, context_length, candidate_length,
                                                      max_seq_length, pickle_src_path, logger,
                                                      inject_ground_truth=params["inject_train_ground_truth"],
                                                      return_data=True)

    valid_dataloader, n_valid_skipped = get_data_loader('valid', tokenizer, context_length, candidate_length,
                                                        max_seq_length, pickle_src_path, logger,
                                                        inject_ground_truth=params["inject_eval_ground_truth"],
                                                        max_n=2048)

    if not params["skip_initial_eval"]:
        logger.info("Evaluating dev set on untrained model...")
        # Evaluate before training
        results = evaluate(
            reranker,
            valid_dataloader,
            device=device,
            logger=logger,
            context_length=context_length,
            silent=params["silent"],
        )

    time_start = time.time()

    utils.write_to_file(
        os.path.join(model_output_path, "training_params.txt"), str(params)
    )

    logger.info("Starting training")
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}".format(device, n_gpu, False)
    )

    optimizer = get_optimizer(model, params)
    scheduler = get_scheduler(params, optimizer, len(train_data["mention_data"]), logger)

    model.train()

    best_epoch_idx = -1
    best_score = -1

    num_train_epochs = params["num_train_epochs"]

    for epoch_idx in trange(int(num_train_epochs), desc="Epoch"):
        tr_loss = 0
        results = None

        if params["silent"]:
            iter_ = train_dataloader
        else:
            iter_ = tqdm(train_dataloader, desc="Batch")

        part = 0
        for step, batch in enumerate(iter_):
            batch = tuple(t.to(device) for t in batch)
            context_input, label_input, _ = batch
            loss, _ = reranker(context_input, label_input, context_length)

            if grad_acc_steps > 1:
                loss = loss / grad_acc_steps

            tr_loss += loss.item()

            if (step + 1) % (params["print_interval"] * grad_acc_steps) == 0:
                logger.info(
                    "Step {} - epoch {} average loss: {}\n".format(
                        step,
                        epoch_idx,
                        tr_loss / (params["print_interval"] * grad_acc_steps),
                    )
                )
                tr_loss = 0

            loss.backward()

            if (step + 1) % grad_acc_steps == 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), params["max_grad_norm"]
                )
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if params["eval_interval"] != -1:
                if (step + 1) % (params["eval_interval"] * grad_acc_steps) == 0:
                    logger.info("Evaluation on the development dataset")
                    evaluate(
                        reranker,
                        valid_dataloader,
                        device=device,
                        logger=logger,
                        context_length=context_length,
                        silent=params["silent"],
                    )
                    logger.info("***** Saving fine - tuned model *****")
                    epoch_output_folder_path = os.path.join(
                        model_output_path, "epoch_{}_{}".format(epoch_idx, part)
                    )
                    part += 1
                    utils.save_model(model, tokenizer, epoch_output_folder_path)
                    model.train()
                    logger.info("\n")

        logger.info("***** Saving fine - tuned model *****")
        epoch_output_folder_path = os.path.join(
            model_output_path, "epoch_{}".format(epoch_idx)
        )
        utils.save_model(model, tokenizer, epoch_output_folder_path)

        results = evaluate(
            reranker,
            valid_dataloader,
            device=device,
            logger=logger,
            context_length=context_length,
            silent=params["silent"],
        )

        # Keep track of the best epoch by validation accuracy
        if results["normalized_accuracy"] > best_score:
            best_score = results["normalized_accuracy"]
            best_epoch_idx = epoch_idx
        logger.info("\n")

    execution_time = (time.time() - time_start) / 60
    utils.write_to_file(
        os.path.join(model_output_path, "training_time.txt"),
        "The training took {} minutes\n".format(execution_time),
    )
    logger.info("The training took {} minutes\n".format(execution_time))

    # Save the best model in the parent_dir
    logger.info("Best performance in epoch: {}".format(best_epoch_idx))
    params["path_to_model"] = os.path.join(
        model_output_path, "epoch_{}".format(best_epoch_idx)
    )
    utils.save_model(reranker.model, tokenizer, model_output_path)
Example #5
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"], 'log-eval')

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path

    embed_data_path = params["embed_data_path"]
    if embed_data_path is None or not os.path.exists(embed_data_path):
        embed_data_path = output_path

    cands_path = os.path.join(output_path, 'cands.pickle')

    # Init model
    reranker = BiEncoderRanker(params)
    reranker.model.eval()
    tokenizer = reranker.tokenizer
    n_gpu = reranker.n_gpu

    knn = params["knn"]
    use_types = params["use_types"]
    within_doc = params["within_doc"]
    data_split = params["data_split"]  # Default = "test"

    # Load test data
    test_samples = None
    entity_dictionary_loaded = False
    test_dictionary_pkl_path = os.path.join(pickle_src_path,
                                            'test_dictionary.pickle')
    test_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                             'test_tensor_data.pickle')
    test_mention_data_pkl_path = os.path.join(pickle_src_path,
                                              'test_mention_data.pickle')
    if params['transductive']:
        train_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                                  'train_tensor_data.pickle')
        train_mention_data_pkl_path = os.path.join(
            pickle_src_path, 'train_mention_data.pickle')
    if os.path.isfile(test_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(test_dictionary_pkl_path, 'rb') as read_handle:
            test_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if os.path.isfile(test_tensor_data_pkl_path) and os.path.isfile(
            test_mention_data_pkl_path):
        print("Loading stored processed test data...")
        with open(test_tensor_data_pkl_path, 'rb') as read_handle:
            test_tensor_data = pickle.load(read_handle)
        with open(test_mention_data_pkl_path, 'rb') as read_handle:
            mention_data = pickle.load(read_handle)
    else:
        test_samples = utils.read_dataset(data_split, params["data_path"])
        if not entity_dictionary_loaded:
            with open(os.path.join(params["data_path"], 'dictionary.pickle'),
                      'rb') as read_handle:
                test_dictionary = pickle.load(read_handle)

        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in test_samples[0].keys()
        if params["filter_unlabeled"]:
            # Filter samples without gold entities
            test_samples = list(
                filter(
                    lambda sample: (len(sample["labels"]) > 0)
                    if mult_labels else (sample["label"] is not None),
                    test_samples))
        logger.info("Read %d test samples." % len(test_samples))

        mention_data, test_dictionary, test_tensor_data = data_process.process_mention_data(
            test_samples,
            test_dictionary,
            tokenizer,
            params["max_context_length"],
            params["max_cand_length"],
            multi_label_key="labels" if mult_labels else None,
            context_key=params["context_key"],
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            knn=knn,
            dictionary_processed=entity_dictionary_loaded)
        print("Saving processed test data...")
        if not entity_dictionary_loaded:
            with open(test_dictionary_pkl_path, 'wb') as write_handle:
                pickle.dump(test_dictionary,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            entity_dictionary_loaded = True
        with open(test_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(test_tensor_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_mention_data_pkl_path, 'wb') as write_handle:
            pickle.dump(mention_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    # Store test dictionary token ids
    test_dict_vecs = torch.tensor(list(map(lambda x: x['ids'],
                                           test_dictionary)),
                                  dtype=torch.long)
    # Store test mention token ids
    test_men_vecs = test_tensor_data[:][0]

    n_entities = len(test_dict_vecs)
    n_mentions = len(test_tensor_data)

    if within_doc:
        if test_samples is None:
            test_samples, _ = read_data(data_split, params, logger)
        test_context_doc_ids = [s['context_doc_id'] for s in test_samples]

    if params["transductive"]:
        if os.path.isfile(train_tensor_data_pkl_path) and os.path.isfile(
                train_mention_data_pkl_path):
            print("Loading stored processed train data...")
            with open(train_tensor_data_pkl_path, 'rb') as read_handle:
                train_tensor_data = pickle.load(read_handle)
            with open(train_mention_data_pkl_path, 'rb') as read_handle:
                train_mention_data = pickle.load(read_handle)
        else:
            train_samples = utils.read_dataset('train', params["data_path"])

            # Check if dataset has multiple ground-truth labels
            mult_labels = "labels" in train_samples[0].keys()
            logger.info("Read %d test samples." % len(test_samples))

            train_mention_data, _, train_tensor_data = data_process.process_mention_data(
                train_samples,
                test_dictionary,
                tokenizer,
                params["max_context_length"],
                params["max_cand_length"],
                multi_label_key="labels" if mult_labels else None,
                context_key=params["context_key"],
                silent=params["silent"],
                logger=logger,
                debug=params["debug"],
                knn=knn,
                dictionary_processed=entity_dictionary_loaded)
            print("Saving processed train data...")
            with open(train_tensor_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_tensor_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            with open(train_mention_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_mention_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

        # Store train mention token ids
        train_men_vecs = train_tensor_data[:][0]
        n_mentions += len(train_tensor_data)
        n_train_mentions = len(train_tensor_data)

    # if os.path.isfile(cands_path):
    #     print("Loading stored candidates...")
    #     with open(cands_path, 'rb') as read_handle:
    #         men_cands = pickle.load(read_handle)
    # else:
    # Check and load stored embedding data
    embed_data_path = os.path.join(embed_data_path, 'embed_data.t7')
    embed_data = None
    if os.path.isfile(embed_data_path):
        embed_data = torch.load(embed_data_path)

    if use_types:
        if embed_data is not None:
            logger.info('Loading stored embeddings and computing indexes')
            dict_embeds = embed_data['dict_embeds']
            if 'dict_idxs_by_type' in embed_data:
                dict_idxs_by_type = embed_data['dict_idxs_by_type']
            else:
                dict_idxs_by_type = data_process.get_idxs_by_type(
                    test_dictionary)
            dict_indexes = data_process.get_index_from_embeds(
                dict_embeds,
                dict_idxs_by_type,
                force_exact_search=params['force_exact_search'],
                probe_mult_factor=params['probe_mult_factor'])
            men_embeds = embed_data['men_embeds']
            if 'men_idxs_by_type' in embed_data:
                men_idxs_by_type = embed_data['men_idxs_by_type']
            else:
                men_idxs_by_type = data_process.get_idxs_by_type(mention_data)
            men_indexes = data_process.get_index_from_embeds(
                men_embeds,
                men_idxs_by_type,
                force_exact_search=params['force_exact_search'],
                probe_mult_factor=params['probe_mult_factor'])
        else:
            logger.info("Dictionary: Embedding and building index")
            dict_embeds, dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
                reranker,
                test_dict_vecs,
                encoder_type="candidate",
                n_gpu=n_gpu,
                corpus=test_dictionary,
                force_exact_search=params['force_exact_search'],
                batch_size=params['embed_batch_size'],
                probe_mult_factor=params['probe_mult_factor'])
            logger.info("Queries: Embedding and building index")
            vecs = test_men_vecs
            men_data = mention_data
            if params['transductive']:
                vecs = torch.cat((train_men_vecs, vecs), dim=0)
                men_data = train_mention_data + mention_data
            men_embeds, men_indexes, men_idxs_by_type = data_process.embed_and_index(
                reranker,
                vecs,
                encoder_type="context",
                n_gpu=n_gpu,
                corpus=men_data,
                force_exact_search=params['force_exact_search'],
                batch_size=params['embed_batch_size'],
                probe_mult_factor=params['probe_mult_factor'])
    else:
        if embed_data is not None:
            logger.info('Loading stored embeddings and computing indexes')
            dict_embeds = embed_data['dict_embeds']
            dict_index = data_process.get_index_from_embeds(
                dict_embeds,
                force_exact_search=params['force_exact_search'],
                probe_mult_factor=params['probe_mult_factor'])
            men_embeds = embed_data['men_embeds']
            men_index = data_process.get_index_from_embeds(
                men_embeds,
                force_exact_search=params['force_exact_search'],
                probe_mult_factor=params['probe_mult_factor'])
        else:
            logger.info("Dictionary: Embedding and building index")
            dict_embeds, dict_index = data_process.embed_and_index(
                reranker,
                test_dict_vecs,
                'candidate',
                n_gpu=n_gpu,
                force_exact_search=params['force_exact_search'],
                batch_size=params['embed_batch_size'],
                probe_mult_factor=params['probe_mult_factor'])
            logger.info("Queries: Embedding and building index")
            vecs = test_men_vecs
            if params['transductive']:
                vecs = torch.cat((train_men_vecs, vecs), dim=0)
            men_embeds, men_index = data_process.embed_and_index(
                reranker,
                vecs,
                'context',
                n_gpu=n_gpu,
                force_exact_search=params['force_exact_search'],
                batch_size=params['embed_batch_size'],
                probe_mult_factor=params['probe_mult_factor'])

    # Save computed embedding data if not loaded from disk
    if embed_data is None:
        embed_data = {}
        embed_data['dict_embeds'] = dict_embeds
        embed_data['men_embeds'] = men_embeds
        if use_types:
            embed_data['dict_idxs_by_type'] = dict_idxs_by_type
            embed_data['men_idxs_by_type'] = men_idxs_by_type
        # NOTE: Cannot pickle faiss index because it is a SwigPyObject
        torch.save(embed_data,
                   embed_data_path,
                   pickle_protocol=pickle.HIGHEST_PROTOCOL)

    logger.info("Starting KNN search...")
    # Fetch (k+1) NN mention candidates
    if not use_types:
        _men_embeds = men_embeds
        if params['transductive']:
            _men_embeds = _men_embeds[n_train_mentions:]
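        # k+1 because each mention retrieves itself as its own nearest
        # neighbor; the self-match is filtered out below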
        n_mens_to_fetch = len(_men_embeds) if within_doc else knn + 1
        nn_men_dists, nn_men_idxs = men_index.search(_men_embeds,
                                                     n_mens_to_fetch)
    else:
        query_len = len(men_embeds) - (n_train_mentions
                                       if params['transductive'] else 0)
        nn_men_idxs = -1 * np.ones((query_len, query_len), dtype=int)
        nn_men_dists = -1 * np.ones((query_len, query_len), dtype='float64')
        for entity_type in men_indexes:
            if params['transductive']:
                # Exclude train mentions from the query set
                men_embeds_by_type = men_embeds[men_idxs_by_type[entity_type][
                    men_idxs_by_type[entity_type] >= n_train_mentions]]
            else:
                men_embeds_by_type = men_embeds[men_idxs_by_type[entity_type]]
            n_mens_to_fetch = len(
                men_embeds_by_type) if within_doc else knn + 1
            nn_men_dists_by_type, nn_men_idxs_by_type = men_indexes[
                entity_type].search(
                    men_embeds_by_type,
                    min(n_mens_to_fetch, len(men_embeds_by_type)))
            nn_men_idxs_by_type = np.array(
                list(
                    map(lambda x: men_idxs_by_type[entity_type][x],
                        nn_men_idxs_by_type)))
            i = -1
            for idx in men_idxs_by_type[entity_type]:
                if params['transductive']:
                    idx -= n_train_mentions
                if idx < 0:
                    continue
                i += 1
                nn_men_idxs[idx][:len(nn_men_idxs_by_type[i])] = \
                    nn_men_idxs_by_type[i]
                nn_men_dists[idx][:len(nn_men_dists_by_type[i])] = \
                    nn_men_dists_by_type[i]
    logger.info("Search finished")

    logger.info(f'Calculating mention recall@{knn}')
    # Find the most similar entity and k-nn mentions for each mention query
    men_recall_knn = []
    men_cands = []
    cui_sums = defaultdict(int)
    for m in mention_data:
        cui_sums[m['label_cuis'][0]] += 1
    for idx in range(len(nn_men_idxs)):
        filter_mask_neg1 = nn_men_idxs[idx] != -1
        men_cand_idxs = nn_men_idxs[idx][filter_mask_neg1]
        men_cand_scores = nn_men_dists[idx][filter_mask_neg1]

        if within_doc:
            men_cand_idxs, wd_mask = filter_by_context_doc_id(
                men_cand_idxs,
                test_context_doc_ids[idx],
                test_context_doc_ids,
                return_numpy=True)
            men_cand_scores = men_cand_scores[wd_mask]

        # Filter candidates to remove mention query and keep only the top k candidates
        filter_mask = men_cand_idxs != idx
        men_cand_idxs = men_cand_idxs[filter_mask][:knn]
        men_cand_scores = men_cand_scores[filter_mask][:knn]
        assert len(men_cand_idxs) == knn
        men_cands.append(men_cand_idxs)
        # Calculate mention recall@k
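        # Only mentions whose gold CUI is shared with at least one other
        # mention are scored; the denominator is the number of coreferent
        # mentions, capped at k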
        gold_cui = mention_data[idx]['label_cuis'][0]
        if cui_sums[gold_cui] > 1:
            recall_hit = [
                1. for midx in men_cand_idxs
                if mention_data[midx]['label_cuis'][0] == gold_cui
            ]
            recall_knn = sum(recall_hit) / min(cui_sums[gold_cui] - 1, knn)
            men_recall_knn.append(recall_knn)
    logger.info('Done')
    # assert len(men_recall_knn) == len(nn_men_idxs)

    # Pickle the graphs
    print(f"Saving top-{knn} candidates")
    with open(cands_path, 'wb') as write_handle:
        pickle.dump(men_cands, write_handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Output final results
    mean_recall_knn = np.mean(men_recall_knn)
    logger.info(f"Result: Final mean recall@{knn} = {mean_recall_knn}")
Example #6
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"])

    # Init model
    reranker = BiEncoderRanker(params)
    tokenizer = reranker.tokenizer
    model = reranker.model

    device = reranker.device

    cand_encode_path = params.get("cand_encode_path", None)

    # The candidate encoding may not be pre-computed, so
    # load/generate the candidate pool used to compute it.
    cand_pool_path = params.get("cand_pool_path", None)
    candidate_pool = load_or_generate_candidate_pool(
        tokenizer,
        params,
        logger,
        cand_pool_path,
    )

    candidate_encoding = None
    if cand_encode_path is not None:
        # try to load candidate encoding from path
        # if success, avoid computing candidate encoding
        try:
            logger.info("Loading pre-generated candidate encode path.")
            candidate_encoding = torch.load(cand_encode_path)
        except Exception:
            logger.info("Loading failed. Generating candidate encoding.")

    if candidate_encoding is None:
        candidate_encoding = encode_candidate_zeshel(
            reranker,
            candidate_pool,
            params["encode_batch_size"],
            silent=params["silent"],
            logger=logger,
        )

        if cand_encode_path is not None:
            # Save candidate encoding to avoid re-compute
            logger.info("Saving candidate encoding to file " +
                        cand_encode_path)
            torch.save(candidate_encoding, cand_encode_path)

    test_samples = utils.read_dataset(params["mode"], params["data_path"])
    logger.info("Read %d test samples." % len(test_samples))

    test_data, test_tensor_data = data.process_mention_data(
        test_samples,
        tokenizer,
        params["max_context_length"],
        params["max_cand_length"],
        context_key=params['context_key'],
        silent=params["silent"],
        logger=logger,
        debug=params["debug"],
    )
    test_sampler = SequentialSampler(test_tensor_data)
    test_dataloader = DataLoader(test_tensor_data,
                                 sampler=test_sampler,
                                 batch_size=params["encode_batch_size"])

    save_results = params.get("save_topk_result")
    new_data = nnquery.get_topk_predictions(
        reranker,
        test_dataloader,
        candidate_pool,
        candidate_encoding,
        params["silent"],
        logger,
        params["top_k"],
        params.get("zeshel", None),
        save_results,
    )

    if save_results:
        save_data_path = os.path.join(
            params['output_path'],
            'candidates_%s_top%d.t7' % (params['mode'], params['top_k']))
        torch.save(new_data, save_data_path)
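
The load-or-compute pattern above is easy to get wrong because torch.save takes the object first and the destination second. A standalone sketch of the same caching idiom (the function and file names are illustrative):

import os
import torch

def cached_encoding(path, compute_fn):
    # Return the cached tensor if present; otherwise compute and cache it
    if os.path.isfile(path):
        return torch.load(path)
    encoding = compute_fn()
    torch.save(encoding, path)  # object first, then path
    return encoding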
Example #7
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"], 'log-eval')

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path

    embed_data_path = params["embed_data_path"]
    if embed_data_path is None or not os.path.exists(embed_data_path):
        embed_data_path = output_path

    # Init model
    reranker = BiEncoderRanker(params)
    reranker.model.eval()
    tokenizer = reranker.tokenizer
    n_gpu = reranker.n_gpu

    knn = params["knn"]
    use_types = params["use_types"]
    within_doc = params["within_doc"]
    data_split = params["data_split"]  # Default = "test"

    # Load test data
    test_samples = None
    entity_dictionary_loaded = False
    test_dictionary_pkl_path = os.path.join(pickle_src_path,
                                            'test_dictionary.pickle')
    test_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                             'test_tensor_data.pickle')
    test_mention_data_pkl_path = os.path.join(pickle_src_path,
                                              'test_mention_data.pickle')
    if params['transductive']:
        train_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                                  'train_tensor_data.pickle')
        train_mention_data_pkl_path = os.path.join(
            pickle_src_path, 'train_mention_data.pickle')
    if os.path.isfile(test_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(test_dictionary_pkl_path, 'rb') as read_handle:
            test_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if os.path.isfile(test_tensor_data_pkl_path) and os.path.isfile(
            test_mention_data_pkl_path):
        print("Loading stored processed test data...")
        with open(test_tensor_data_pkl_path, 'rb') as read_handle:
            test_tensor_data = pickle.load(read_handle)
        with open(test_mention_data_pkl_path, 'rb') as read_handle:
            mention_data = pickle.load(read_handle)
    else:
        test_samples = utils.read_dataset(data_split, params["data_path"])
        if not entity_dictionary_loaded:
            with open(os.path.join(params["data_path"], 'dictionary.pickle'),
                      'rb') as read_handle:
                test_dictionary = pickle.load(read_handle)

        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in test_samples[0].keys()
        if params["filter_unlabeled"]:
            # Filter samples without gold entities
            test_samples = list(
                filter(
                    lambda sample: (len(sample["labels"]) > 0)
                    if mult_labels else (sample["label"] is not None),
                    test_samples))
        logger.info("Read %d test samples." % len(test_samples))

        mention_data, test_dictionary, test_tensor_data = data_process.process_mention_data(
            test_samples,
            test_dictionary,
            tokenizer,
            params["max_context_length"],
            params["max_cand_length"],
            multi_label_key="labels" if mult_labels else None,
            context_key=params["context_key"],
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            knn=knn,
            dictionary_processed=entity_dictionary_loaded)
        print("Saving processed test data...")
        if not entity_dictionary_loaded:
            with open(test_dictionary_pkl_path, 'wb') as write_handle:
                pickle.dump(test_dictionary,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            entity_dictionary_loaded = True
        with open(test_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(test_tensor_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_mention_data_pkl_path, 'wb') as write_handle:
            pickle.dump(mention_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    # Store test dictionary token ids
    test_dict_vecs = torch.tensor(list(map(lambda x: x['ids'],
                                           test_dictionary)),
                                  dtype=torch.long)
    # Store test mention token ids
    test_men_vecs = test_tensor_data[:][0]

    n_entities = len(test_dict_vecs)
    n_mentions = len(test_tensor_data)

    if within_doc:
        if test_samples is None:
            test_samples, _ = read_data(data_split, params, logger)
        test_context_doc_ids = [s['context_doc_id'] for s in test_samples]

    if params["transductive"]:
        if os.path.isfile(train_tensor_data_pkl_path) and os.path.isfile(
                train_mention_data_pkl_path):
            print("Loading stored processed train data...")
            with open(train_tensor_data_pkl_path, 'rb') as read_handle:
                train_tensor_data = pickle.load(read_handle)
            with open(train_mention_data_pkl_path, 'rb') as read_handle:
                train_mention_data = pickle.load(read_handle)
        else:
            train_samples = utils.read_dataset('train', params["data_path"])

            # Check if dataset has multiple ground-truth labels
            mult_labels = "labels" in train_samples[0].keys()
            logger.info("Read %d test samples." % len(test_samples))

            train_mention_data, _, train_tensor_data = data_process.process_mention_data(
                train_samples,
                test_dictionary,
                tokenizer,
                params["max_context_length"],
                params["max_cand_length"],
                multi_label_key="labels" if mult_labels else None,
                context_key=params["context_key"],
                silent=params["silent"],
                logger=logger,
                debug=params["debug"],
                knn=knn,
                dictionary_processed=entity_dictionary_loaded)
            print("Saving processed train data...")
            with open(train_tensor_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_tensor_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            with open(train_mention_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_mention_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

        # Store train mention token ids
        train_men_vecs = train_tensor_data[:][0]
        n_mentions += len(train_tensor_data)
        n_train_mentions = len(train_tensor_data)

    # Values of k to run the evaluation against
    knn_vals = [0] + [2**i for i in range(int(math.log(knn, 2)) + 1)]
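    # e.g., knn=16 gives knn_vals = [0, 1, 2, 4, 8, 16]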
    # Store the maximum evaluation k
    max_knn = knn_vals[-1]

    time_start = time.time()

    # Check if graphs are already built
    graph_path = os.path.join(output_path, 'graphs.pickle')
    if not params['only_recall'] and os.path.isfile(graph_path):
        print("Loading stored joint graphs...")
        with open(graph_path, 'rb') as read_handle:
            joint_graphs = pickle.load(read_handle)
    else:
        # Initialize graphs to store mention-mention and mention-entity similarity score edges;
        # Keyed on k, the number of nearest mentions retrieved
        joint_graphs = {}
        for k in knn_vals:
            joint_graphs[k] = {
                'rows': np.array([]),
                'cols': np.array([]),
                'data': np.array([]),
                'shape': (n_entities + n_mentions, n_entities + n_mentions)
            }
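            # 'rows', 'cols', and 'data' accumulate COO-style (row, col,
            # value) triplets; 'shape' covers entities followed by mentions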

        # Check and load stored embedding data
        embed_data_path = os.path.join(embed_data_path, 'embed_data.t7')
        embed_data = None
        if os.path.isfile(embed_data_path):
            embed_data = torch.load(embed_data_path)

        if use_types:
            if embed_data is not None:
                logger.info('Loading stored embeddings and computing indexes')
                dict_embeds = embed_data['dict_embeds']
                if 'dict_idxs_by_type' in embed_data:
                    dict_idxs_by_type = embed_data['dict_idxs_by_type']
                else:
                    dict_idxs_by_type = data_process.get_idxs_by_type(
                        test_dictionary)
                dict_indexes = data_process.get_index_from_embeds(
                    dict_embeds,
                    dict_idxs_by_type,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
                men_embeds = embed_data['men_embeds']
                if 'men_idxs_by_type' in embed_data:
                    men_idxs_by_type = embed_data['men_idxs_by_type']
                else:
                    men_idxs_by_type = data_process.get_idxs_by_type(
                        mention_data)
                men_indexes = data_process.get_index_from_embeds(
                    men_embeds,
                    men_idxs_by_type,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
            else:
                logger.info("Dictionary: Embedding and building index")
                dict_embeds, dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
                    reranker,
                    test_dict_vecs,
                    encoder_type="candidate",
                    n_gpu=n_gpu,
                    corpus=test_dictionary,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])
                logger.info("Queries: Embedding and building index")
                vecs = test_men_vecs
                men_data = mention_data
                if params['transductive']:
                    vecs = torch.cat((train_men_vecs, vecs), dim=0)
                    men_data = train_mention_data + mention_data
                men_embeds, men_indexes, men_idxs_by_type = data_process.embed_and_index(
                    reranker,
                    vecs,
                    encoder_type="context",
                    n_gpu=n_gpu,
                    corpus=men_data,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])
        else:
            if embed_data is not None:
                logger.info('Loading stored embeddings and computing indexes')
                dict_embeds = embed_data['dict_embeds']
                dict_index = data_process.get_index_from_embeds(
                    dict_embeds,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
                men_embeds = embed_data['men_embeds']
                men_index = data_process.get_index_from_embeds(
                    men_embeds,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
            else:
                logger.info("Dictionary: Embedding and building index")
                dict_embeds, dict_index = data_process.embed_and_index(
                    reranker,
                    test_dict_vecs,
                    'candidate',
                    n_gpu=n_gpu,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])
                logger.info("Queries: Embedding and building index")
                vecs = test_men_vecs
                if params['transductive']:
                    vecs = torch.cat((train_men_vecs, vecs), dim=0)
                men_embeds, men_index = data_process.embed_and_index(
                    reranker,
                    vecs,
                    'context',
                    n_gpu=n_gpu,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])

        # Save computed embedding data if not loaded from disk
        if embed_data is None:
            embed_data = {}
            embed_data['dict_embeds'] = dict_embeds
            embed_data['men_embeds'] = men_embeds
            if use_types:
                embed_data['dict_idxs_by_type'] = dict_idxs_by_type
                embed_data['men_idxs_by_type'] = men_idxs_by_type
            # NOTE: Cannot pickle faiss index because it is a SwigPyObject
            torch.save(embed_data,
                       embed_data_path,
                       pickle_protocol=pickle.HIGHEST_PROTOCOL)

        recall_accuracy = {
            2**i: 0
            for i in range(int(math.log(params['recall_k'], 2)) + 1)
        }
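        # e.g., recall_k=16 gives buckets {1: 0, 2: 0, 4: 0, 8: 0, 16: 0}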
        recall_idxs = [0.] * params['recall_k']

        logger.info("Starting KNN search...")
        # Fetch recall_k (default 16) knn entities for all mentions
        # Fetch (k+1) NN mention candidates
        if not use_types:
            _men_embeds = men_embeds
            if params['transductive']:
                _men_embeds = _men_embeds[n_train_mentions:]
            nn_ent_dists, nn_ent_idxs = dict_index.search(
                _men_embeds, params['recall_k'])
            n_mens_to_fetch = len(_men_embeds) if within_doc else max_knn + 1
            nn_men_dists, nn_men_idxs = men_index.search(
                _men_embeds, n_mens_to_fetch)
        else:
            query_len = len(men_embeds) - (n_train_mentions
                                           if params['transductive'] else 0)
            nn_ent_idxs = np.zeros((query_len, params['recall_k']))
            nn_ent_dists = np.zeros((query_len, params['recall_k']),
                                    dtype='float64')
            nn_men_idxs = -1 * np.ones((query_len, query_len), dtype=int)
            nn_men_dists = -1 * np.ones(
                (query_len, query_len), dtype='float64')
            for entity_type in men_indexes:
                if params['transductive']:
                    # Exclude train mentions from the query set
                    men_embeds_by_type = men_embeds[
                        men_idxs_by_type[entity_type][
                            men_idxs_by_type[entity_type] >= n_train_mentions]]
                else:
                    men_embeds_by_type = men_embeds[
                        men_idxs_by_type[entity_type]]
                nn_ent_dists_by_type, nn_ent_idxs_by_type = dict_indexes[
                    entity_type].search(men_embeds_by_type, params['recall_k'])
                nn_ent_idxs_by_type = np.array(
                    list(
                        map(lambda x: dict_idxs_by_type[entity_type][x],
                            nn_ent_idxs_by_type)))
                n_mens_to_fetch = len(
                    men_embeds_by_type) if within_doc else max_knn + 1
                nn_men_dists_by_type, nn_men_idxs_by_type = men_indexes[
                    entity_type].search(
                        men_embeds_by_type,
                        min(n_mens_to_fetch, len(men_embeds_by_type)))
                nn_men_idxs_by_type = np.array(
                    list(
                        map(lambda x: men_idxs_by_type[entity_type][x],
                            nn_men_idxs_by_type)))
                i = -1
                for idx in men_idxs_by_type[entity_type]:
                    if params['transductive']:
                        idx -= n_train_mentions
                    if idx < 0:
                        continue
                    i += 1
                    nn_ent_idxs[idx] = nn_ent_idxs_by_type[i]
                    nn_ent_dists[idx] = nn_ent_dists_by_type[i]
                    nn_men_idxs[idx][:len(nn_men_idxs_by_type[i])] = \
                        nn_men_idxs_by_type[i]
                    nn_men_dists[idx][:len(nn_men_dists_by_type[i])] = \
                        nn_men_dists_by_type[i]
        logger.info("Search finished")

        logger.info('Building graphs')
        # Find the most similar entity and k-nn mentions for each mention query
        for idx in range(len(nn_ent_idxs)):
            # Get nearest entity candidate
            dict_cand_idx = nn_ent_idxs[idx][0]
            dict_cand_score = nn_ent_dists[idx][0]
            # Compute recall metric
            gold_idxs = mention_data[idx][
                "label_idxs"][:mention_data[idx]["n_labels"]]
            recall_idx = np.argwhere(nn_ent_idxs[idx] == gold_idxs[0])
            if len(recall_idx) != 0:
                recall_idx = int(recall_idx)
                recall_idxs[recall_idx] += 1.
                for recall_k in recall_accuracy:
                    if recall_idx < recall_k:
                        recall_accuracy[recall_k] += 1.
            if not params['only_recall']:
                filter_mask_neg1 = nn_men_idxs[idx] != -1
                men_cand_idxs = nn_men_idxs[idx][filter_mask_neg1]
                men_cand_scores = nn_men_dists[idx][filter_mask_neg1]

                if within_doc:
                    men_cand_idxs, wd_mask = filter_by_context_doc_id(
                        men_cand_idxs,
                        test_context_doc_ids[idx],
                        test_context_doc_ids,
                        return_numpy=True)
                    men_cand_scores = men_cand_scores[wd_mask]

                # Filter candidates to remove mention query and keep only the top k candidates
                filter_mask = men_cand_idxs != idx
                men_cand_idxs = men_cand_idxs[filter_mask][:max_knn]
                men_cand_scores = men_cand_scores[filter_mask][:max_knn]

                if params['transductive']:
                    idx += n_train_mentions
                # Add edges to the graphs
                for k in joint_graphs:
                    joint_graph = joint_graphs[k]
                    # Add mention-entity edge
                    # Mentions are added at an offset of n_entities
                    joint_graph['rows'] = np.append(
                        joint_graph['rows'], [n_entities + idx])
                    joint_graph['cols'] = np.append(joint_graph['cols'],
                                                    dict_cand_idx)
                    joint_graph['data'] = np.append(joint_graph['data'],
                                                    dict_cand_score)
                    if k > 0:
                        # Add mention-mention edges
                        joint_graph['rows'] = np.append(
                            joint_graph['rows'],
                            [n_entities + idx] * len(men_cand_idxs[:k]))
                        joint_graph['cols'] = np.append(
                            joint_graph['cols'],
                            n_entities + men_cand_idxs[:k])
                        joint_graph['data'] = np.append(
                            joint_graph['data'], men_cand_scores[:k])

        if params['transductive']:
            # Add positive infinity mention-entity edges from training queries to labeled entities
            for idx, train_men in enumerate(train_mention_data):
                dict_cand_idx = train_men["label_idxs"][0]
                for k in joint_graphs:
                    joint_graph = joint_graphs[k]
                    # Mentions are added at an offset of n_entities
                    joint_graph['rows'] = np.append(
                        joint_graph['rows'], [n_entities + idx])
                    joint_graph['cols'] = np.append(joint_graph['cols'],
                                                    dict_cand_idx)
                    joint_graph['data'] = np.append(joint_graph['data'],
                                                    float('inf'))

        # Compute and print recall metric
        recall_idx_mode = np.argmax(recall_idxs)
        recall_idx_mode_prop = recall_idxs[recall_idx_mode] / np.sum(
            recall_idxs)
        logger.info(f"""
        Recall metrics (for {len(nn_ent_idxs)} queries):
        ---------------""")
        logger.info(
            f"highest recall idx = {recall_idx_mode} ({recall_idxs[recall_idx_mode]}/{np.sum(recall_idxs)} = {recall_idx_mode_prop})"
        )
        for recall_k in recall_accuracy:
            recall_accuracy[recall_k] /= len(nn_ent_idxs)
            logger.info(f"recall@{recall_k} = {recall_accuracy[recall_k]}")

        if params['only_recall']:
            exit()

        # Pickle the graphs
        print("Saving joint graphs...")
        with open(graph_path, 'wb') as write_handle:
            pickle.dump(joint_graphs,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

        if params['only_embed_and_build']:
            logger.info(f"Saved embedding data at: {embed_data_path}")
            logger.info(f"Saved graphs at: {graph_path}")
            exit()

    graph_mode = params.get('graph_mode', None)

    result_overview = {
        'n_entities': n_entities,
        'n_mentions': n_mentions - (n_train_mentions
                                    if params['transductive'] else 0)
    }
    results = {}
    if graph_mode is None or graph_mode not in ['directed', 'undirected']:
        results['directed'] = []
        results['undirected'] = []
    else:
        results[graph_mode] = []

    knn_fetch_time = time.time() - time_start
    graph_processing_time = time.time()
    n_graphs_processed = 0.

    for mode in results:
        print(f'\nEvaluation mode: {mode.upper()}')
        for k in joint_graphs:
            if k <= knn:
                print(f"\nGraph (k={k}):")
                # Partition graph based on cluster-linking constraints
                partitioned_graph, clusters = partition_graph(
                    joint_graphs[k],
                    n_entities,
                    mode == 'directed',
                    return_clusters=True)
                # Infer predictions from clusters
                result = analyzeClusters(
                    clusters, test_dictionary, mention_data, k,
                    n_train_mentions if params['transductive'] else 0)
                # Store result
                results[mode].append(result)
                n_graphs_processed += 1

    avg_graph_processing_time = (time.time() -
                                 graph_processing_time) / n_graphs_processed
    avg_per_graph_time = (knn_fetch_time + avg_graph_processing_time) / 60

    execution_time = (time.time() - time_start) / 60
    # Store results
    # Timestamp (epoch seconds) used as a unique run identifier
    output_file_name = os.path.join(
        output_path, f"eval_results_{int(time.time())}")

    try:
        for recall_k in recall_accuracy:
            result_overview[f'recall@{recall_k}'] = recall_accuracy[recall_k]
    except NameError:
        logger.info(
            "Recall data not available since graphs were loaded from disk")

    for mode in results:
        mode_results = results[mode]
        result_overview[mode] = {}
        for r in mode_results:
            k = r['knn_mentions']
            result_overview[mode][f'accuracy@knn{k}'] = r['accuracy']
            logger.info(f"{mode} accuracy@knn{k} = {r['accuracy']}")
            output_file = f'{output_file_name}-{mode}-{k}.json'
            with open(output_file, 'w') as f:
                json.dump(r, f, indent=2)
                print(
                    f"\nPredictions ({mode}) @knn{k} saved at: {output_file}")
    with open(f'{output_file_name}.json', 'w') as f:
        json.dump(result_overview, f, indent=2)
        print(f"\nPredictions overview saved at: {output_file_name}.json")

    logger.info("\nThe avg. per graph evaluation time is {} minutes\n".format(
        avg_per_graph_time))
    logger.info(
        "\nThe total evaluation took {} minutes\n".format(execution_time))
Example #8
def main(params):
    model_output_path = params["output_path"]
    if not os.path.exists(model_output_path):
        os.makedirs(model_output_path)
    logger = utils.get_logger(params["output_path"])

    # Init model
    reranker = BiEncoderRanker(params)
    tokenizer = reranker.tokenizer
    model = reranker.model

    # utils.save_model(model, tokenizer, model_output_path)

    device = reranker.device
    n_gpu = reranker.n_gpu

    if params["gradient_accumulation_steps"] < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(params["gradient_accumulation_steps"]))

    # An effective batch size of `x`, when accumulating the gradient across `y` batches, is achieved with a per-step batch size of `z = x / y`
    # args.gradient_accumulation_steps = args.gradient_accumulation_steps // n_gpu
    params["train_batch_size"] = (params["train_batch_size"] //
                                  params["gradient_accumulation_steps"])
    train_batch_size = params["train_batch_size"]
    eval_batch_size = params["eval_batch_size"]
    grad_acc_steps = params["gradient_accumulation_steps"]

    # Fix the random seeds
    seed = params["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if reranker.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    # Load train data
    train_samples = utils.read_dataset("train", params["data_path"])
    logger.info("Read %d train samples." % len(train_samples))

    train_data, train_tensor_data = data.process_mention_data(
        train_samples,
        tokenizer,
        params["max_context_length"],
        params["max_cand_length"],
        context_key=params["context_key"],
        silent=params["silent"],
        logger=logger,
        debug=params["debug"],
    )
    if params["shuffle"]:
        train_sampler = RandomSampler(train_tensor_data)
    else:
        train_sampler = SequentialSampler(train_tensor_data)

    train_dataloader = DataLoader(train_tensor_data,
                                  sampler=train_sampler,
                                  batch_size=train_batch_size)

    # Load eval data
    # TODO: reduce duplicated code here
    valid_samples = utils.read_dataset("valid", params["data_path"])
    logger.info("Read %d valid samples." % len(valid_samples))

    valid_data, valid_tensor_data = data.process_mention_data(
        valid_samples,
        tokenizer,
        params["max_context_length"],
        params["max_cand_length"],
        context_key=params["context_key"],
        silent=params["silent"],
        logger=logger,
        debug=params["debug"],
    )
    valid_sampler = SequentialSampler(valid_tensor_data)
    valid_dataloader = DataLoader(valid_tensor_data,
                                  sampler=valid_sampler,
                                  batch_size=eval_batch_size)

    # evaluate before training
    results = evaluate(
        reranker,
        valid_dataloader,
        params,
        device=device,
        logger=logger,
    )

    number_of_samples_per_dataset = {}

    time_start = time.time()

    utils.write_to_file(os.path.join(model_output_path, "training_params.txt"),
                        str(params))

    logger.info("Starting training")
    logger.info("device: {} n_gpu: {}, distributed training: {}".format(
        device, n_gpu, False))

    optimizer = get_optimizer(model, params)
    scheduler = get_scheduler(params, optimizer, len(train_tensor_data),
                              logger)

    model.train()

    best_epoch_idx = -1
    best_score = -1

    num_train_epochs = params["num_train_epochs"]
    for epoch_idx in trange(int(num_train_epochs), desc="Epoch"):
        tr_loss = 0
        results = None

        if params["silent"]:
            iter_ = train_dataloader
        else:
            iter_ = tqdm(train_dataloader, desc="Batch")

        for step, batch in enumerate(iter_):
            batch = tuple(t.to(device) for t in batch)
            if params["zeshel"]:
                context_input, candidate_input, _, _ = batch
            else:
                context_input, candidate_input, _ = batch
            loss, _ = reranker(context_input, candidate_input)

            # if n_gpu > 1:
            #     loss = loss.mean() # mean() to average on multi-gpu.

            if grad_acc_steps > 1:
                loss = loss / grad_acc_steps

            tr_loss += loss.item()

            if (step + 1) % (params["print_interval"] * grad_acc_steps) == 0:
                logger.info("Step {} - epoch {} average loss: {}\n".format(
                    step,
                    epoch_idx,
                    tr_loss / (params["print_interval"] * grad_acc_steps),
                ))
                tr_loss = 0

            loss.backward()

            if (step + 1) % grad_acc_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               params["max_grad_norm"])
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if (step + 1) % (params["eval_interval"] * grad_acc_steps) == 0:
                logger.info("Evaluation on the development dataset")
                evaluate(
                    reranker,
                    valid_dataloader,
                    params,
                    device=device,
                    logger=logger,
                )
                model.train()
                logger.info("\n")

        logger.info("***** Saving fine - tuned model *****")
        epoch_output_folder_path = os.path.join(model_output_path,
                                                "epoch_{}".format(epoch_idx))
        utils.save_model(model, tokenizer, epoch_output_folder_path)

        output_eval_file = os.path.join(epoch_output_folder_path,
                                        "eval_results.txt")
        results = evaluate(
            reranker,
            valid_dataloader,
            params,
            device=device,
            logger=logger,
        )

        ls = [best_score, results["normalized_accuracy"]]
        li = [best_epoch_idx, epoch_idx]

        best_score = ls[np.argmax(ls)]
        best_epoch_idx = li[np.argmax(ls)]
        logger.info("\n")

    execution_time = (time.time() - time_start) / 60
    utils.write_to_file(
        os.path.join(model_output_path, "training_time.txt"),
        "The training took {} minutes\n".format(execution_time),
    )
    logger.info("The training took {} minutes\n".format(execution_time))

    # save the best model in the parent_dir
    logger.info("Best performance in epoch: {}".format(best_epoch_idx))
    params["path_to_model"] = os.path.join(model_output_path,
                                           "epoch_{}".format(best_epoch_idx))
    utils.save_model(reranker.model, tokenizer, model_output_path)

    if params["evaluate"]:
        params["path_to_model"] = model_output_path
        results = evaluate(
            reranker,
            valid_dataloader,
            params,
            device=device,
            logger=logger,
        )
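
The gradient-accumulation bookkeeping above (scaling the loss by grad_acc_steps and stepping the optimizer every grad_acc_steps batches) generalizes to any trainer; a minimal sketch under those assumptions, with hypothetical names, not BLINK code:

import torch

def train_with_accumulation(model, optimizer, loader, grad_acc_steps):
    # With a per-step batch size of z and grad_acc_steps = y, the optimizer
    # effectively sees gradients averaged over x = z * y samples.
    model.train()
    for step, (inputs, targets) in enumerate(loader):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        (loss / grad_acc_steps).backward()  # scale so accumulated grads average
        if (step + 1) % grad_acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()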
Example #9
File: test2.py  Project: helderarr/WS
        wikipedia_id2local_id,
        faiss_indexer,
    ) = _load_candidates(
        entity_catalogue,
        entity_encoding,
        faiss_index=None,
        index_path=None,
        logger=logger,
    )

    return biencoder, biencoder_params, candidate_encoding, faiss_indexer


top_k = 1

logger = utils.get_logger('output')
biencoder, biencoder_params, candidate_encoding, faiss_indexer = load_models(
    logger=logger)
ner_model = NER.get_model()
samples = _annotate(ner_model, [
    'What is throat cancer? Throat cancer is any cancer that forms in the throat. The throat, also called the pharynx, is a 5-inch-long tube that runs from your nose to your neck. The larynx (voice box) and pharynx are the two main places throat cancer forms. Throat cancer is a type of head and neck cancer, which includes cancer of the mouth, tonsils, nose, sinuses, salivary glands and neck lymph nodes.',
    "Juan Carlos is the king os Spain", "Cristiano Ronaldo has 5 Ballon D'Or"
])
print(samples)
dataloader = _process_biencoder_dataloader(samples, biencoder.tokenizer,
                                           biencoder_params)
labels, nns, scores = _run_biencoder(biencoder, dataloader, candidate_encoding,
                                     top_k, faiss_indexer)

print(labels, nns, scores)
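
When faiss_indexer is supplied, the retrieval step here roughly amounts to an exact top-k search over the candidate encodings; a self-contained sketch with made-up dimensions (illustrative only, not the repo's DenseFlatIndexer):

import faiss
import numpy as np

d = 64                                                   # embedding size (illustrative)
rng = np.random.default_rng(0)
candidates = rng.standard_normal((1000, d)).astype("float32")
queries = rng.standard_normal((3, d)).astype("float32")  # one per detected mention

index = faiss.IndexFlatIP(d)   # exact inner-product search
index.add(candidates)
scores, nns = index.search(queries, 1)                   # top_k = 1, as above
print(nns.ravel(), scores.ravel())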
Example #10
def main(params):
    model_output_path = params["output_path"]
    if not os.path.exists(model_output_path):
        os.makedirs(model_output_path)
    logger = utils.get_logger(params["output_path"])

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = model_output_path

    knn = params["knn"]
    use_types = params["use_types"]

    # Init model
    reranker = BiEncoderRanker(params)
    tokenizer = reranker.tokenizer
    model = reranker.model

    device = reranker.device
    n_gpu = reranker.n_gpu

    if params["gradient_accumulation_steps"] < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(params["gradient_accumulation_steps"]))

    # To reach an effective batch size of `x` while accumulating gradients across `y` batches, use a per-step batch size of `z = x / y`
    params["train_batch_size"] = (params["train_batch_size"] //
                                  params["gradient_accumulation_steps"])
    train_batch_size = params["train_batch_size"]
    eval_batch_size = params["eval_batch_size"]
    grad_acc_steps = params["gradient_accumulation_steps"]

    # Fix the random seeds
    seed = params["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if reranker.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    entity_dictionary_loaded = False
    entity_dictionary_pkl_path = os.path.join(pickle_src_path,
                                              'entity_dictionary.pickle')
    if os.path.isfile(entity_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(entity_dictionary_pkl_path, 'rb') as read_handle:
            entity_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if not params["only_evaluate"]:
        # Load train data
        train_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                                  'train_tensor_data.pickle')
        train_processed_data_pkl_path = os.path.join(
            pickle_src_path, 'train_processed_data.pickle')
        if os.path.isfile(train_tensor_data_pkl_path) and os.path.isfile(
                train_processed_data_pkl_path):
            print("Loading stored processed train data...")
            with open(train_tensor_data_pkl_path, 'rb') as read_handle:
                train_tensor_data = pickle.load(read_handle)
            with open(train_processed_data_pkl_path, 'rb') as read_handle:
                train_processed_data = pickle.load(read_handle)
        else:
            train_samples = utils.read_dataset("train", params["data_path"])
            if not entity_dictionary_loaded:
                with open(
                        os.path.join(params["data_path"], 'dictionary.pickle'),
                        'rb') as read_handle:
                    entity_dictionary = pickle.load(read_handle)

            # Check if dataset has multiple ground-truth labels
            mult_labels = "labels" in train_samples[0].keys()
            if params["filter_unlabeled"]:
                # Filter samples without gold entities
                train_samples = list(
                    filter(
                        lambda sample: (len(sample["labels"]) > 0)
                        if mult_labels else (sample["label"] is not None),
                        train_samples))
            logger.info("Read %d train samples." % len(train_samples))

            # For discovery experiment: Drop entities used in training that were dropped randomly from dev/test set
            if params["drop_entities"]:
                assert entity_dictionary_loaded
                drop_set_pkl_path = os.path.join(
                    pickle_src_path, 'drop_set_mention_data.pickle')
                with open(drop_set_pkl_path, 'rb') as read_handle:
                    drop_set_data = pickle.load(read_handle)
                drop_set_mention_gold_cui_idxs = list(
                    map(lambda x: x['label_idxs'][0], drop_set_data))
                ents_in_data = np.unique(drop_set_mention_gold_cui_idxs)
                ent_drop_prop = 0.1
                logger.info(
                    f"Dropping {ent_drop_prop*100}% of {len(ents_in_data)} entities found in drop set"
                )
                # Get entity indices to drop
                n_ents_dropped = int(ent_drop_prop * len(ents_in_data))
                rng = np.random.default_rng(seed=17)
                dropped_ent_idxs = rng.choice(ents_in_data,
                                              size=n_ents_dropped,
                                              replace=False)

                # Drop entities from dictionary (subsequent processing will automatically drop corresponding mentions)
                keep_mask = np.ones(len(entity_dictionary), dtype='bool')
                keep_mask[dropped_ent_idxs] = False
                entity_dictionary = np.array(entity_dictionary)[keep_mask]

            train_processed_data, entity_dictionary, train_tensor_data = data_process.process_mention_data(
                train_samples,
                entity_dictionary,
                tokenizer,
                params["max_context_length"],
                params["max_cand_length"],
                context_key=params["context_key"],
                multi_label_key="labels" if mult_labels else None,
                silent=params["silent"],
                logger=logger,
                debug=params["debug"],
                knn=knn,
                dictionary_processed=entity_dictionary_loaded)
            print("Saving processed train data...")
            if not entity_dictionary_loaded:
                with open(entity_dictionary_pkl_path, 'wb') as write_handle:
                    pickle.dump(entity_dictionary,
                                write_handle,
                                protocol=pickle.HIGHEST_PROTOCOL)
            with open(train_tensor_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_tensor_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            with open(train_processed_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_processed_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

        # Store the query mention vectors
        train_men_vecs = train_tensor_data[:][0]

        if params["shuffle"]:
            train_sampler = RandomSampler(train_tensor_data)
        else:
            train_sampler = SequentialSampler(train_tensor_data)

        train_dataloader = DataLoader(train_tensor_data,
                                      sampler=train_sampler,
                                      batch_size=train_batch_size)

    # Store the entity dictionary vectors
    entity_dict_vecs = torch.tensor(list(
        map(lambda x: x['ids'], entity_dictionary)),
                                    dtype=torch.long)

    # Load eval data
    valid_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                              'valid_tensor_data.pickle')
    valid_processed_data_pkl_path = os.path.join(
        pickle_src_path, 'valid_processed_data.pickle')
    if os.path.isfile(valid_tensor_data_pkl_path) and os.path.isfile(
            valid_processed_data_pkl_path):
        print("Loading stored processed valid data...")
        with open(valid_tensor_data_pkl_path, 'rb') as read_handle:
            valid_tensor_data = pickle.load(read_handle)
        with open(valid_processed_data_pkl_path, 'rb') as read_handle:
            valid_processed_data = pickle.load(read_handle)
    else:
        valid_samples = utils.read_dataset("valid", params["data_path"])
        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in valid_samples[0].keys()
        # Filter samples without gold entities
        valid_samples = list(
            filter(
                lambda sample: (len(sample["labels"]) > 0) if mult_labels else
                (sample["label"] is not None), valid_samples))
        logger.info("Read %d valid samples." % len(valid_samples))

        valid_processed_data, _, valid_tensor_data = data_process.process_mention_data(
            valid_samples,
            entity_dictionary,
            tokenizer,
            params["max_context_length"],
            params["max_cand_length"],
            context_key=params["context_key"],
            multi_label_key="labels" if mult_labels else None,
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            knn=knn,
            dictionary_processed=True)
        print("Saving processed valid data...")
        with open(valid_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(valid_tensor_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        with open(valid_processed_data_pkl_path, 'wb') as write_handle:
            pickle.dump(valid_processed_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    # Store the query mention vectors
    valid_men_vecs = valid_tensor_data[:][0]

    # valid_sampler = SequentialSampler(valid_tensor_data)
    # valid_dataloader = DataLoader(
    #     valid_tensor_data, sampler=valid_sampler, batch_size=eval_batch_size
    # )

    if params["only_evaluate"]:
        evaluate(reranker,
                 entity_dict_vecs,
                 valid_men_vecs,
                 device=device,
                 logger=logger,
                 knn=knn,
                 n_gpu=n_gpu,
                 entity_data=entity_dictionary,
                 query_data=valid_processed_data,
                 silent=params["silent"],
                 use_types=use_types or params["use_types_for_eval"],
                 embed_batch_size=params['embed_batch_size'],
                 force_exact_search=use_types or params["use_types_for_eval"]
                 or params["force_exact_search"],
                 probe_mult_factor=params['probe_mult_factor'])
        exit()

    time_start = time.time()
    utils.write_to_file(os.path.join(model_output_path, "training_params.txt"),
                        str(params))
    logger.info("Starting training")
    logger.info("device: {} n_gpu: {}, data_parallel: {}".format(
        device, n_gpu, params["data_parallel"]))

    # Set model to training mode
    optimizer = get_optimizer(model, params)
    scheduler = get_scheduler(params, optimizer, len(train_tensor_data),
                              logger)
    best_epoch_idx = -1
    best_score = -1
    num_train_epochs = params["num_train_epochs"]

    init_base_model_run = params.get("path_to_model", None) is None
    init_run_pkl_path = os.path.join(
        pickle_src_path, f'init_run_{"type" if use_types else "notype"}.t7')

    dict_embed_data = None

    for epoch_idx in trange(int(num_train_epochs), desc="Epoch"):
        model.train()
        torch.cuda.empty_cache()
        tr_loss = 0
        results = None

        # Check if embeddings and index can be loaded
        init_run_data_loaded = False
        if init_base_model_run:
            if os.path.isfile(init_run_pkl_path):
                logger.info('Loading init run data')
                init_run_data = torch.load(init_run_pkl_path)
                init_run_data_loaded = True
        load_stored_data = init_base_model_run and init_run_data_loaded

        # Compute mention and entity embeddings at the start of each epoch
        if use_types:
            if load_stored_data:
                train_dict_embeddings, dict_idxs_by_type = init_run_data[
                    'train_dict_embeddings'], init_run_data[
                        'dict_idxs_by_type']
                train_dict_indexes = data_process.get_index_from_embeds(
                    train_dict_embeddings,
                    dict_idxs_by_type,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
                train_men_embeddings, men_idxs_by_type = init_run_data[
                    'train_men_embeddings'], init_run_data['men_idxs_by_type']
                train_men_indexes = data_process.get_index_from_embeds(
                    train_men_embeddings,
                    men_idxs_by_type,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
            else:
                logger.info('Embedding and indexing')
                if dict_embed_data is not None:
                    train_dict_embeddings, train_dict_indexes, dict_idxs_by_type = dict_embed_data[
                        'dict_embeds'], dict_embed_data[
                            'dict_indexes'], dict_embed_data[
                                'dict_idxs_by_type']
                else:
                    train_dict_embeddings, train_dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
                        reranker,
                        entity_dict_vecs,
                        encoder_type="candidate",
                        n_gpu=n_gpu,
                        corpus=entity_dictionary,
                        force_exact_search=params['force_exact_search'],
                        batch_size=params['embed_batch_size'],
                        probe_mult_factor=params['probe_mult_factor'])
                train_men_embeddings, train_men_indexes, men_idxs_by_type = data_process.embed_and_index(
                    reranker,
                    train_men_vecs,
                    encoder_type="context",
                    n_gpu=n_gpu,
                    corpus=train_processed_data,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])
        else:
            if load_stored_data:
                train_dict_embeddings = init_run_data['train_dict_embeddings']
                train_dict_index = data_process.get_index_from_embeds(
                    train_dict_embeddings,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
                train_men_embeddings = init_run_data['train_men_embeddings']
                train_men_index = data_process.get_index_from_embeds(
                    train_men_embeddings,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
            else:
                logger.info('Embedding and indexing')
                if dict_embed_data is not None:
                    train_dict_embeddings, train_dict_index = dict_embed_data[
                        'dict_embeds'], dict_embed_data['dict_index']
                else:
                    train_dict_embeddings, train_dict_index = data_process.embed_and_index(
                        reranker,
                        entity_dict_vecs,
                        encoder_type="candidate",
                        n_gpu=n_gpu,
                        force_exact_search=params['force_exact_search'],
                        batch_size=params['embed_batch_size'],
                        probe_mult_factor=params['probe_mult_factor'])
                train_men_embeddings, train_men_index = data_process.embed_and_index(
                    reranker,
                    train_men_vecs,
                    encoder_type="context",
                    n_gpu=n_gpu,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])

        # Save the initial embeddings and index if this is the first run and data isn't persistent
        if init_base_model_run and not load_stored_data:
            init_run_data = {}
            init_run_data['train_dict_embeddings'] = train_dict_embeddings
            init_run_data['train_men_embeddings'] = train_men_embeddings
            if use_types:
                init_run_data['dict_idxs_by_type'] = dict_idxs_by_type
                init_run_data['men_idxs_by_type'] = men_idxs_by_type
            # NOTE: Cannot pickle faiss index because it is a SwigPyObject
            torch.save(init_run_data,
                       init_run_pkl_path,
                       pickle_protocol=pickle.HIGHEST_PROTOCOL)

        init_base_model_run = False

        if params["silent"]:
            iter_ = train_dataloader
        else:
            iter_ = tqdm(train_dataloader, desc="Batch")

        logger.info("Starting KNN search...")
        if not use_types:
            _, dict_nns = train_dict_index.search(train_men_embeddings, knn)
        else:
            dict_nns = np.zeros((len(train_men_embeddings), knn))
            for entity_type in train_men_indexes:
                men_embeds_by_type = train_men_embeddings[
                    men_idxs_by_type[entity_type]]
                _, dict_nns_by_type = train_dict_indexes[entity_type].search(
                    men_embeds_by_type, knn)
                dict_nns_idxs = np.array(
                    list(
                        map(lambda x: dict_idxs_by_type[entity_type][x],
                            dict_nns_by_type)))
                for i, idx in enumerate(men_idxs_by_type[entity_type]):
                    dict_nns[idx] = dict_nns_idxs[i]
        logger.info("Search finished")

        for step, batch in enumerate(iter_):
            batch = tuple(t.to(device) for t in batch)
            context_inputs, candidate_idxs, n_gold, mention_idxs = batch
            mention_embeddings = train_men_embeddings[mention_idxs.cpu()]

            if len(mention_embeddings.shape) == 1:
                mention_embeddings = np.expand_dims(mention_embeddings, axis=0)

            # context_inputs: Shape: batch x token_len
            candidate_inputs = np.array(
                [], dtype=np.int64)  # Shape: (batch*knn) x token_len
            label_inputs = torch.tensor(
                [[1] + [0] * (knn - 1)] * n_gold.sum(),
                dtype=torch.float32)  # Shape: batch(with split rows) x knn
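            # Position 0 of each row is the gold candidate (placed first in the loop below), hence [1, 0, ..., 0]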
            context_inputs_split = torch.zeros(
                (label_inputs.size(0), context_inputs.size(1)),
                dtype=torch.long)  # Shape: batch(with split rows) x token_len
            # label_inputs = (candidate_idxs >= 0).type(torch.float32) # Shape: batch x knn

            for i, m_embed in enumerate(mention_embeddings):
                knn_dict_idxs = dict_nns[mention_idxs[i]]
                knn_dict_idxs = knn_dict_idxs.astype(np.int64).flatten()
                gold_idxs = candidate_idxs[i][:n_gold[i]].cpu()
                for ng, gold_idx in enumerate(gold_idxs):
                    context_inputs_split[i + ng] = context_inputs[i]
                    candidate_inputs = np.concatenate(
                        (candidate_inputs,
                         np.concatenate(
                             ([gold_idx],
                              knn_dict_idxs[~np.isin(knn_dict_idxs, gold_idxs)]
                              ))[:knn]))
            candidate_inputs = torch.tensor(
                list(
                    map(lambda x: entity_dict_vecs[x].numpy(),
                        candidate_inputs))).cuda()
            context_inputs_split = context_inputs_split.cuda()
            label_inputs = label_inputs.cuda()

            loss, _ = reranker(context_inputs_split,
                               candidate_inputs,
                               label_inputs,
                               pos_neg_loss=params["pos_neg_loss"])

            if grad_acc_steps > 1:
                loss = loss / grad_acc_steps

            tr_loss += loss.item()

            if (step + 1) % (params["print_interval"] * grad_acc_steps) == 0:
                logger.info("Step {} - epoch {} average loss: {}\n".format(
                    step,
                    epoch_idx,
                    tr_loss / (params["print_interval"] * grad_acc_steps),
                ))
                tr_loss = 0

            loss.backward()

            if (step + 1) % grad_acc_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               params["max_grad_norm"])
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if (step + 1) % (params["eval_interval"] * grad_acc_steps) == 0:
                logger.info("Evaluation on the development dataset")
                evaluate(reranker,
                         entity_dict_vecs,
                         valid_men_vecs,
                         device=device,
                         logger=logger,
                         knn=knn,
                         n_gpu=n_gpu,
                         entity_data=entity_dictionary,
                         query_data=valid_processed_data,
                         silent=params["silent"],
                         use_types=use_types or params["use_types_for_eval"],
                         embed_batch_size=params['embed_batch_size'],
                         force_exact_search=use_types
                         or params["use_types_for_eval"]
                         or params["force_exact_search"],
                         probe_mult_factor=params['probe_mult_factor'])
                model.train()
                logger.info("\n")

        logger.info("***** Saving fine-tuned model *****")
        epoch_output_folder_path = os.path.join(model_output_path,
                                                "epoch_{}".format(epoch_idx))
        utils.save_model(model, tokenizer, epoch_output_folder_path)
        logger.info(f"Model saved at {epoch_output_folder_path}")

        normalized_accuracy, dict_embed_data = evaluate(
            reranker,
            entity_dict_vecs,
            valid_men_vecs,
            device=device,
            logger=logger,
            knn=knn,
            n_gpu=n_gpu,
            entity_data=entity_dictionary,
            query_data=valid_processed_data,
            silent=params["silent"],
            use_types=use_types or params["use_types_for_eval"],
            embed_batch_size=params['embed_batch_size'],
            force_exact_search=use_types or params["use_types_for_eval"]
            or params["force_exact_search"],
            probe_mult_factor=params['probe_mult_factor'])

        ls = [best_score, normalized_accuracy]
        li = [best_epoch_idx, epoch_idx]

        best_score = ls[np.argmax(ls)]
        best_epoch_idx = li[np.argmax(ls)]
        logger.info("\n")

    execution_time = (time.time() - time_start) / 60
    utils.write_to_file(
        os.path.join(model_output_path, "training_time.txt"),
        "The training took {} minutes\n".format(execution_time),
    )
    logger.info("The training took {} minutes\n".format(execution_time))

    # save the best model in the parent_dir
    logger.info("Best performance in epoch: {}".format(best_epoch_idx))
    params["path_to_model"] = os.path.join(model_output_path,
                                           "epoch_{}".format(best_epoch_idx))
    utils.save_model(reranker.model, tokenizer, model_output_path)
    logger.info(f"Best model saved at {model_output_path}")
Example #11
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"])

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path

    # Init model
    reranker = BiEncoderRanker(params)
    reranker.model.eval()
    tokenizer = reranker.tokenizer
    model = reranker.model
    device = reranker.device
    n_gpu = reranker.n_gpu

    knn = params["knn"]
    directed_graph = params["directed_graph"]
    use_types = params["use_types"]
    data_split = params["data_split"] # Parameter default is "test"

    # Load test data
    entity_dictionary_loaded = False
    test_dictionary_pkl_path = os.path.join(pickle_src_path, 'test_dictionary.pickle')
    test_tensor_data_pkl_path = os.path.join(pickle_src_path, 'test_tensor_data.pickle')
    test_mention_data_pkl_path = os.path.join(pickle_src_path, 'test_mention_data.pickle')
    if os.path.isfile(test_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(test_dictionary_pkl_path, 'rb') as read_handle:
            test_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if os.path.isfile(test_tensor_data_pkl_path) and os.path.isfile(test_mention_data_pkl_path):
        print("Loading stored processed test data...")
        with open(test_tensor_data_pkl_path, 'rb') as read_handle:
            test_tensor_data = pickle.load(read_handle)
        with open(test_mention_data_pkl_path, 'rb') as read_handle:
            mention_data = pickle.load(read_handle)
    else:
        test_samples = utils.read_dataset(data_split, params["data_path"])
        if not entity_dictionary_loaded:
            with open(os.path.join(params["data_path"], 'dictionary.pickle'), 'rb') as read_handle:
                test_dictionary = pickle.load(read_handle)

        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in test_samples[0].keys()
        if params["filter_unlabeled"]:
            # Filter samples without gold entities
            test_samples = list(
                filter(
                    lambda sample: (len(sample["labels"]) > 0) if mult_labels
                    else (sample["label"] is not None), test_samples))
        logger.info("Read %d test samples." % len(test_samples))

        mention_data, test_dictionary, test_tensor_data = data_process.process_mention_data(
            test_samples,
            test_dictionary,
            tokenizer,
            params["max_context_length"],
            params["max_cand_length"],
            multi_label_key="labels" if mult_labels else None,
            context_key=params["context_key"],
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            knn=knn,
            dictionary_processed=entity_dictionary_loaded
        )
        print("Saving processed test data...")
        if not entity_dictionary_loaded:
            with open(test_dictionary_pkl_path, 'wb') as write_handle:
                pickle.dump(test_dictionary, write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(test_tensor_data, write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_mention_data_pkl_path, 'wb') as write_handle:
            pickle.dump(mention_data, write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    # Store test dictionary token ids
    test_dict_vecs = torch.tensor(
        list(map(lambda x: x['ids'], test_dictionary)), dtype=torch.long)
    # Store test mention token ids
    test_men_vecs = test_tensor_data[:][0]

    n_entities = len(test_dict_vecs)
    n_mentions = len(test_tensor_data)

    # Values of k to run the evaluation against
    knn_vals = [0] + [2**i for i in range(int(math.log(knn, 2)) + 1)]
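    # e.g. knn = 16 gives knn_vals = [0, 1, 2, 4, 8, 16]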
    # Store the maximum evaluation k
    max_knn = knn_vals[-1]

    # Check if graphs are already built
    graph_path = os.path.join(output_path, 'graphs.pickle')
    if not params['only_recall'] and os.path.isfile(graph_path):
        print("Loading stored joint graphs...")
        with open(graph_path, 'rb') as read_handle:
            joint_graphs = pickle.load(read_handle)
    else:
        # Initialize graphs to store mention-mention and mention-entity similarity score edges;
        # Keyed on k, the number of nearest mentions retrieved
        joint_graphs = {}
        for k in knn_vals:
            joint_graphs[k] = {
                'rows': np.array([]),
                'cols': np.array([]),
                'data': np.array([]),
                'shape': (n_entities+n_mentions, n_entities+n_mentions)
            }

        if use_types:
            print("Dictionary: Embedding and building index")
            dict_embeds, dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
                reranker, test_dict_vecs, encoder_type="candidate",
                n_gpu=n_gpu, corpus=test_dictionary, force_exact_search=True)

            # Debug-only check: verify the freshly computed embeddings against a
            # stored reference run. Left commented out so the evaluation below
            # actually runs.
            # og_embeds = torch.load('models/trained/zeshel_og/eval/data_og/cand_encodes.t7')
            # world_to_type = {12: 'forgotten_realms', 13: 'lego', 14: 'star_trek', 15: 'yugioh'}
            # for world in og_embeds:
            #     for i, oge in enumerate(tqdm(og_embeds[world])):
            #         dict_embed_idx = dict_idxs_by_type[world_to_type[world]][i]
            #         if not torch.equal(oge, torch.tensor(dict_embeds[dict_embed_idx])):
            #             embed()  # drop into IPython to inspect the mismatch
            #             exit()
            # print('PASS')
            print("Queries: Embedding and building index")
            men_embeds, men_indexes, men_idxs_by_type = data_process.embed_and_index(
                reranker, test_men_vecs, encoder_type="context",
                n_gpu=n_gpu, corpus=mention_data, force_exact_search=True)
        else:
            print("Dictionary: Embedding and building index")
            dict_embeds, dict_index = data_process.embed_and_index(
                reranker, test_dict_vecs, 'candidate', n_gpu=n_gpu)
            print("Queries: Embedding and building index")
            men_embeds, men_index = data_process.embed_and_index(
                reranker, test_men_vecs, 'context', n_gpu=n_gpu)

        recall_accuracy = {2**i: 0 for i in range(int(math.log(params['recall_k'], 2)) + 1)}
        recall_idxs = [0.]*params['recall_k']

        # Find the most similar entity and k-nn mentions for each mention query
        for men_query_idx, men_embed in enumerate(tqdm(men_embeds, total=len(men_embeds), desc="Fetching k-NN")):
            men_embed = np.expand_dims(men_embed, axis=0)
            
            dict_type_idx_mapping, men_type_idx_mapping = None, None
            if use_types:
                entity_type = mention_data[men_query_idx]['type']
                dict_index = dict_indexes[entity_type]
                men_index = men_indexes[entity_type]
                dict_type_idx_mapping = dict_idxs_by_type[entity_type]
                men_type_idx_mapping = men_idxs_by_type[entity_type]
            
            # Fetch nearest entity candidate
            gold_idxs = mention_data[men_query_idx]["label_idxs"][:mention_data[men_query_idx]["n_labels"]]
            dict_cand_idx, dict_cand_score, recall_idx = get_query_nn(
                1, dict_embeds, dict_index, men_embed, searchK=params['recall_k'], gold_idxs=gold_idxs, type_idx_mapping=dict_type_idx_mapping)
            # Compute recall metric
            if recall_idx > -1:
                recall_idxs[recall_idx] += 1.
                for recall_k in recall_accuracy:
                    if recall_idx < recall_k:
                        recall_accuracy[recall_k] += 1.

            if not params['only_recall']:
                # Fetch (k+1) NN mention candidates
                men_cand_idxs, men_cand_scores = get_query_nn(
                    max_knn + 1, men_embeds, men_index, men_embed, type_idx_mapping=men_type_idx_mapping)
                # Filter candidates to remove mention query and keep only the top k candidates
                filter_mask = men_cand_idxs != men_query_idx
                men_cand_idxs, men_cand_scores = men_cand_idxs[filter_mask][:max_knn], men_cand_scores[filter_mask][:max_knn]

                # Add edges to the graphs
                for k in joint_graphs:
                    joint_graph = joint_graphs[k]
                    # Add mention-entity edge
                    joint_graph['rows'] = np.append(
                        joint_graph['rows'], [n_entities+men_query_idx])  # Mentions added at an offset of maximum entities
                    joint_graph['cols'] = np.append(
                        joint_graph['cols'], dict_cand_idx)
                    joint_graph['data'] = np.append(
                        joint_graph['data'], dict_cand_score)
                    if k > 0:
                        # Add mention-mention edges
                        joint_graph['rows'] = np.append(
                            joint_graph['rows'], [n_entities+men_query_idx]*len(men_cand_idxs[:k]))
                        joint_graph['cols'] = np.append(
                            joint_graph['cols'], n_entities+men_cand_idxs[:k])
                        joint_graph['data'] = np.append(
                            joint_graph['data'], men_cand_scores[:k])

        # Compute and print recall metric
        recall_idx_mode = np.argmax(recall_idxs)
        recall_idx_mode_prop = recall_idxs[recall_idx_mode]/np.sum(recall_idxs)
        logger.info(f"""
        Recall metrics (for {len(men_embeds)} queries):
        ---------------""")
        logger.info(f"highest recall idx = {recall_idx_mode} ({recall_idxs[recall_idx_mode]}/{np.sum(recall_idxs)} = {recall_idx_mode_prop})")
        for recall_k in recall_accuracy:
            recall_accuracy[recall_k] /= len(men_embeds)
            logger.info(f"recall@{recall_k} = {recall_accuracy[recall_k]}")

        if params['only_recall']:
            exit()

        # Pickle the graphs
        print("Saving joint graphs...")
        with open(graph_path, 'wb') as write_handle:
            pickle.dump(joint_graphs, write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    results = []
    for k in joint_graphs:
        print(f"\nGraph (k={k}):")
        # Partition graph based on cluster-linking constraints
        partitioned_graph, clusters = partition_graph(
            joint_graphs[k], n_entities, directed_graph, return_clusters=True)
        # Infer predictions from clusters
        result = analyzeClusters(clusters, test_dictionary, mention_data, k)
        # Store result
        results.append(result)

    # Store results
    output_file_name = os.path.join(
        output_path, f"eval_results_{int(time.time())}")  # epoch-seconds timestamp
    result_overview = {
        'n_entities': results[0]['n_entities'],
        'n_mentions': results[0]['n_mentions'],
        'directed': directed_graph
    }

    try:
        for recall_k in recall_accuracy:
            result_overview[f'recall@{recall_k}'] = recall_accuracy[recall_k]
    except NameError:
        # recall_accuracy is only defined when the graphs were built in this run
        logger.info("Recall data not available since graphs were loaded from disk")
    
    for r in results:
        k = r['knn_mentions']
        result_overview[f'accuracy@knn{k}'] = r['accuracy']
        logger.info(f"accuracy@knn{k} = {r['accuracy']}")
        output_file = f'{output_file_name}-{k}.json'
        with open(output_file, 'w') as f:
            json.dump(r, f, indent=2)
            print(f"\nPredictions @knn{k} saved at: {output_file}")
    with open(f'{output_file_name}.json', 'w') as f:
        json.dump(result_overview, f, indent=2)
        print(f"\nPredictions overview saved at: {output_file_name}.json")
Example #12
def main(params):

    # create output dir
    eval_output_path = os.path.join(params["output_path"], "eval_output")
    if not os.path.exists(eval_output_path):
        os.makedirs(eval_output_path)
    # get logger
    logger = utils.get_logger(eval_output_path)

    # output command ran
    cmd = sys.argv
    cmd.insert(0, "python")
    logger.info(" ".join(cmd))

    # load the models
    assert params["path_to_model"] is None
    params["path_to_model"] = params["path_to_ctxt_model"]
    ctxt_reranker = CrossEncoderRanker(params)
    ctxt_model = ctxt_reranker.model

    params["pool_highlighted"] = False
    params["path_to_model"] = params["path_to_cand_model"]
    cand_reranker = CrossEncoderRanker(params)
    cand_model = cand_reranker.model

    params["path_to_model"] = None
    tokenizer = ctxt_reranker.tokenizer

    device = ctxt_reranker.device
    n_gpu = ctxt_reranker.n_gpu
    context_length = params["max_context_length"]

    # fix the random seeds
    seed = params["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if ctxt_reranker.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    # create eval dataloaders
    fname = os.path.join(params["data_path"],
                         "joint_" + params["mode"] + ".t7")
    eval_data = torch.load(fname)
    ctxt_dataloader = create_eval_dataloader(params, eval_data["contexts"],
                                             eval_data["context_uids"],
                                             eval_data["knn_ctxts"],
                                             eval_data["knn_ctxt_uids"])
    cand_dataloader = create_eval_dataloader(params, eval_data["contexts"],
                                             eval_data["context_uids"],
                                             eval_data["knn_cands"],
                                             eval_data["knn_cand_uids"])

    # construct ground truth data
    gold_linking_map, gold_coref_clusters = build_ground_truth(eval_data)

    # get all of the edges
    ctxt_edges, cand_edges = None, None
    ctxt_edges = score_contexts(
        ctxt_reranker,
        ctxt_dataloader,
        device=device,
        logger=logger,
        context_length=context_length,
        suffix="ctxt",
        silent=params["silent"],
    )
    cand_edges = score_contexts(
        cand_reranker,
        cand_dataloader,
        device=device,
        logger=logger,
        context_length=context_length,
        suffix="cand",
        silent=params["silent"],
    )

    # construct the sparse graphs
    sparse_shape = tuple(2 * [max(gold_linking_map.keys()) + 1])

    _ctxt_data = ctxt_edges[:, 2].cpu().numpy()
    _ctxt_row = ctxt_edges[:, 0].cpu().numpy()
    _ctxt_col = ctxt_edges[:, 1].cpu().numpy()
    ctxt_graph = coo_matrix((_ctxt_data, (_ctxt_row, _ctxt_col)),
                            shape=sparse_shape)

    _cand_data = cand_edges[:, 2].cpu().numpy()
    _cand_row = cand_edges[:, 1].cpu().numpy()
    _cand_col = cand_edges[:, 0].cpu().numpy()
    cand_graph = coo_matrix((_cand_data, (_cand_row, _cand_col)),
                            shape=sparse_shape)

    logger.info('Computing coref metrics...')
    coref_metrics = compute_coref_metrics(gold_coref_clusters, ctxt_graph)
    logger.info('Done.')

    logger.info('Computing linking metrics...')
    linking_metrics, slim_linking_graph = compute_linking_metrics(
        cand_graph, gold_linking_map)
    logger.info('Done.')

    logger.info('Computing joint metrics...')
    slim_coref_graph = _get_global_maximum_spanning_tree([ctxt_graph])
    joint_metrics = compute_joint_metrics(
        [slim_coref_graph, slim_linking_graph], gold_linking_map,
        min(gold_linking_map.keys()))
    logger.info('Done.')

    metrics = {
        'coref_fmi': coref_metrics['fmi'],
        'coref_rand_index': coref_metrics['rand_index'],
        'coref_threshold': coref_metrics['threshold'],
        'vanilla_recall': linking_metrics['vanilla_recall'],
        'vanilla_accuracy': linking_metrics['vanilla_accuracy'],
        'joint_accuracy': joint_metrics['joint_accuracy'],
        'joint_cc_recall': joint_metrics['joint_cc_recall']
    }

    logger.info('joint_metrics: {}'.format(
        json.dumps(metrics, sort_keys=True, indent=4)))

    # save all of the predictions for later analysis
    save_data = {}
    save_data.update(coref_metrics)
    save_data.update(linking_metrics)
    save_data.update(joint_metrics)

    save_fname = os.path.join(eval_output_path, 'results.t7')
    torch.save(save_data, save_fname)
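
compute_coref_metrics above reports an FMI and a Rand index over predicted coreference clusters; given flat cluster labels, both have off-the-shelf scikit-learn counterparts (a sketch of the metrics themselves, not the repo's graph-based implementation):

from sklearn.metrics import fowlkes_mallows_score, rand_score

gold = [0, 0, 1, 1, 2]
pred = [0, 0, 1, 2, 2]
print(fowlkes_mallows_score(gold, pred), rand_score(gold, pred))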
Example #13
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"])

    # Init model 
    reranker = BiEncoderRanker(params)
    tokenizer = reranker.tokenizer

    # load entities
    entity_dict, entity_json = load_entity_dict(logger, params)

    # load tfidf candidates
    tfidf_cand_dict = read_tfidf_cands(params["data_path"], params["mode"])

    # load mentions
    test_samples = utils.read_dataset(params["mode"], params["data_path"])
    logger.info("Read %d test samples." % len(test_samples))

    # get only the cands we need to tokenize
    cand_ids = [c for l in tfidf_cand_dict.values() for c in l]
    cand_ids.extend([x["label_umls_cuid"] for x in test_samples])
    cand_ids = list(set(cand_ids))
    num_cands = len(cand_ids)

    # tokenize the candidates
    cand_uid_map = {c : i for i, c in enumerate(cand_ids)}
    candidate_pool = get_candidate_pool_tensor(
        [entity_dict[c] for c in cand_ids],
        tokenizer,
        params["max_cand_length"],
        logger
    )

    # create mention maps
    ctxt_uid_map = {x["mm_mention_id"] : i + num_cands
                        for i, x in enumerate(test_samples)}
    ctxt_cand_map = {x["mm_mention_id"] : x["label_umls_cuid"]
                        for x in test_samples}
    ctxt_doc_map = {x["mm_mention_id"] : x["context_doc_id"]
                        for x in test_samples}
    doc_ctxt_map = defaultdict(list)
    for c, d in ctxt_doc_map.items():
        doc_ctxt_map[d].append(c)

    # create text maps for investigative evaluation
    uid_to_json = {
        uid : entity_json[cuid] for cuid, uid in cand_uid_map.items()
    }
    uid_to_json.update({i+num_cands : x for i, x in enumerate(test_samples)})

    # tokenize the contexts
    test_data, test_tensor_data = data.process_mention_data(
        test_samples,
        tokenizer,
        params["max_context_length"],
        params["max_cand_length"],
        context_key=params['context_key'],
        silent=params["silent"],
        logger=logger,
        debug=params["debug"],
    )
    context_pool = test_data["context_vecs"]
    
    # create output variables
    contexts = context_pool
    context_uids = torch.LongTensor(list(ctxt_uid_map.values()))

    pos_coref_ctxts = []
    pos_coref_ctxt_uids = []
    for i, c in enumerate(ctxt_doc_map.keys()):
        assert ctxt_uid_map[c] == i + num_cands
        doc = ctxt_doc_map[c]
        coref_ctxts = [x for x in doc_ctxt_map[doc]
                          if x != c and ctxt_cand_map[x] == ctxt_cand_map[c]]
        coref_ctxt_uids = [ctxt_uid_map[x] for x in coref_ctxts]
        coref_ctxt_idxs = [x - num_cands for x in coref_ctxt_uids]
        pos_coref_ctxts.append(context_pool[coref_ctxt_idxs])
        pos_coref_ctxt_uids.append(torch.LongTensor(coref_ctxt_uids))

    knn_ctxts = []
    knn_ctxt_uids = []
    for i, c in enumerate(ctxt_doc_map.keys()):
        assert ctxt_uid_map[c] == i + num_cands
        doc = ctxt_doc_map[c]
        wdoc_ctxts = [x for x in doc_ctxt_map[doc] if x != c]
        wdoc_ctxt_uids = [ctxt_uid_map[x] for x in wdoc_ctxts]
        wdoc_ctxt_idxs = [x - num_cands for x in wdoc_ctxt_uids]
        knn_ctxts.append(context_pool[wdoc_ctxt_idxs])
        knn_ctxt_uids.append(torch.LongTensor(wdoc_ctxt_uids))
        
    pos_cands = []
    pos_cand_uids = []
    for i, c in enumerate(ctxt_cand_map.keys()):
        assert ctxt_uid_map[c] == i + num_cands
        pos_cands.append(candidate_pool[cand_uid_map[ctxt_cand_map[c]]])
        pos_cand_uids.append(torch.LongTensor([cand_uid_map[ctxt_cand_map[c]]]))

    knn_cands = []
    knn_cand_uids = []
    for i, c in enumerate(ctxt_cand_map.keys()):
        assert ctxt_uid_map[c] == i + num_cands
        tfidf_cands = tfidf_cand_dict.get(c, [])
        tfidf_cand_uids = [cand_uid_map[x] for x in tfidf_cands]
        knn_cands.append(candidate_pool[tfidf_cand_uids])
        knn_cand_uids.append(torch.LongTensor(tfidf_cand_uids))

    tfidf_data = {
        "contexts" : contexts,
        "context_uids":  context_uids,
        "pos_coref_ctxts":  pos_coref_ctxts,
        "pos_coref_ctxt_uids":  pos_coref_ctxt_uids,
        "knn_ctxts":  knn_ctxts,
        "knn_ctxt_uids":  knn_ctxt_uids,
        "pos_cands":  pos_cands,
        "pos_cand_uids":  pos_cand_uids,
        "knn_cands":  knn_cands,
        "knn_cand_uids":  knn_cand_uids,
        "uid_to_json":  uid_to_json,
    }
    
    save_data_path = os.path.join(
        params['output_path'], 
        'joint_candidates_%s_tfidf.t7' % (params['mode'])
    )
    torch.save(tfidf_data, save_data_path)
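
The uid scheme in this example packs candidate uids into [0, num_cands) and mention uids after them, which is why mention rows are recovered with `x - num_cands` throughout; a toy illustration (hypothetical helper and sizes):

def uid_kind(uid, num_cands=100):
    # Candidates occupy [0, num_cands); mentions sit at an offset of num_cands.
    return ("candidate", uid) if uid < num_cands else ("mention", uid - num_cands)

print(uid_kind(5), uid_kind(107))  # ('candidate', 5) ('mention', 7)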
Example #14
def main(params):

    # create output dir
    eval_output_path = os.path.join(params["output_path"], "eval_output")
    if not os.path.exists(eval_output_path):
        os.makedirs(eval_output_path)
    # get logger
    logger = utils.get_logger(eval_output_path)

    # output command ran
    cmd = sys.argv
    cmd.insert(0, "python")
    logger.info(" ".join(cmd))

    params["pool_highlighted"] = False
    params["path_to_model"] = params["path_to_cand_model"]
    cand_reranker = CrossEncoderRanker(params)
    cand_model = cand_reranker.model

    params["path_to_model"] = None
    tokenizer = cand_reranker.tokenizer

    device = cand_reranker.device
    n_gpu = cand_reranker.n_gpu
    context_length = params["max_context_length"]

    # fix the random seeds
    seed = params["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cand_reranker.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    # create eval dataloaders
    fname = os.path.join(
        params["data_path"],
        "joint_" + params["mode"] +".t7"
    )
    eval_data = torch.load(fname)
    cand_dataloader = create_eval_dataloader(
        params,
        eval_data["contexts"],
        eval_data["context_uids"],
        eval_data["knn_cands"],
        eval_data["knn_cand_uids"]
    )

    # get all of the edges
    cand_edges = None

    dev_cache_path = os.path.join(eval_output_path, 'taggerOne_test_cand_edges.t7')
    if not os.path.exists(dev_cache_path):
        cand_edges = score_contexts(
            cand_reranker,
            cand_dataloader,
            device=device,
            logger=logger,
            context_length=context_length,
            suffix="cand",
            silent=params["silent"],
        )
    else:
        cand_edges = torch.load(dev_cache_path)

    # construct ground truth data
    gold_linking_map = build_ground_truth(eval_data)

    # compute TaggerOne pred metrics
    taggerOne_pred_metrics = get_taggerOne_metrics(eval_data)

    # construct the sparse graphs
    sparse_shape = tuple(2*[max(gold_linking_map.keys())+1])

    _cand_data = cand_edges[:, 2].cpu().numpy()
    _cand_row = cand_edges[:, 1].cpu().numpy()
    _cand_col = cand_edges[:, 0].cpu().numpy()
    cand_graph = coo_matrix(
        (_cand_data, (_cand_row, _cand_col)), shape=sparse_shape
    )

    logger.info('Computing linking metrics...')
    linking_metrics, slim_linking_graph = compute_linking_metrics(
        cand_graph, gold_linking_map
    )
    logger.info('Done.')

    metrics = {
        'e2e_taggerOne_cand_gen_recall' : linking_metrics['vanilla_recall'],
        'e2e_vanilla_precision' : linking_metrics['vanilla_accuracy'],
    }
    metrics.update(taggerOne_pred_metrics)

    logger.info('metrics: {}'.format(
        json.dumps(metrics, sort_keys=True, indent=4)
    ))

    # save all of the predictions for later analysis
    save_data = {}
    save_data.update(metrics)
    save_data.update(linking_metrics)

    save_fname = os.path.join(eval_output_path, 'taggerOne_test_results.t7')
    torch.save(save_data, save_fname)
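
Note that the dev_cache_path branch above loads cached edges when present but, as written, never writes them after computing; a symmetric compute-or-load helper would look like this (generic sketch, not from the repo):

import os
import torch

def cached(path, compute_fn):
    # Load a cached object if present; otherwise compute and persist it.
    if os.path.exists(path):
        return torch.load(path)
    value = compute_fn()
    torch.save(value, path)
    return value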
Example #15
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"])

    # Init model
    reranker = BiEncoderRanker(params)
    tokenizer = reranker.tokenizer
    model = reranker.model

    device = reranker.device

    cand_encode_path = params.get("cand_encode_path", None)

    # candidate encoding is not pre-computed.
    # load/generate candidate pool to compute candidate encoding.
    cand_pool_path = params.get("cand_pool_path", None)
    candidate_pool = load_or_generate_candidate_pool(
        tokenizer,
        params,
        logger,
        cand_pool_path,
    )

    candidate_encoding = None
    if cand_encode_path is not None:
        # try to load candidate encoding from path
        # if success, avoid computing candidate encoding
        try:
            logger.info("Loading pre-generated candidate encode path.")
            candidate_encoding = torch.load(cand_encode_path)
        except Exception:
            logger.info("Loading failed. Generating candidate encoding.")

    if candidate_encoding is None:
        candidate_encoding = encode_candidate(reranker,
                                              candidate_pool,
                                              params["encode_batch_size"],
                                              silent=params["silent"],
                                              logger=logger,
                                              is_zeshel=params.get(
                                                  "zeshel", None))

        if cand_encode_path is not None:
            # Save candidate encoding to avoid re-compute
            logger.info("Saving candidate encoding to file " +
                        cand_encode_path)
            torch.save(candidate_encoding, cand_encode_path)

    test_samples = utils.read_dataset(params["mode"], params["data_path"])

    # test_samples_custom = utils.read_dataset(params["mode"], 'data/zeshel/processed')

    # # Copy custom dataset to original except 'label_id'
    # for i in range(len(test_samples)):
    #     for k in test_samples[i]:
    #         if k == 'label_id':
    #             continue
    #         k_custom = 'type' if k == 'world' else k
    #         test_samples[i][k] = test_samples_custom[i][k_custom]

    logger.info("Read %d test samples." % len(test_samples))

    test_data, test_tensor_data = data.process_mention_data(
        test_samples,
        tokenizer,
        params["max_context_length"],
        params["max_cand_length"],
        context_key=params['context_key'],
        silent=params["silent"],
        logger=logger,
        debug=params["debug"],
    )

    # custom_embeds = torch.load('models/trained/zeshel_og/eval/data_og/custom_embed.t7')
    # dict_idxs_by_type = torch.load('models/trained/zeshel_og/eval/data_og/dict_idx_mapping.t7')
    # test_dict_vecs = torch.load('models/trained/zeshel_og/eval/data_og/custom_test_ent_vecs.t7')
    # test_men_vecs = torch.load('models/trained/zeshel_og/eval/data_og/custom_test_men_vecs.t7')
    # world_to_type = {12:'forgotten_realms', 13:'lego', 14:'star_trek', 15:'yugioh'}
    # embed()

    test_sampler = SequentialSampler(test_tensor_data)
    test_dataloader = DataLoader(test_tensor_data,
                                 sampler=test_sampler,
                                 batch_size=params["encode_batch_size"])

    save_results = params.get("save_topk_result")
    new_data = nnquery.get_topk_predictions(
        reranker,
        test_dataloader,
        candidate_pool,
        candidate_encoding,
        params["silent"],
        logger,
        params["top_k"],
        params.get("zeshel", None),
        save_results,
    )

    if save_results:
        save_data_path = os.path.join(
            params['output_path'],
            'candidates_%s_top%d.t7' % (params['mode'], params['top_k']))
        torch.save(new_data, save_data_path)
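
For reference, a minimal sketch (hypothetical helper, not the BLINK API) of what the top-k retrieval over the cached candidate encoding amounts to once the mention embeddings are in memory:

import torch

def topk_candidates(mention_embeds, candidate_encoding, top_k=64):
    # Score every mention against every candidate with a dot product,
    # then keep the top_k highest-scoring candidate indices per mention
    scores = mention_embeds @ candidate_encoding.T  # (n_mentions, n_candidates)
    return torch.topk(scores, k=top_k, dim=1)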
Example #16
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"], 'log')

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path

    embed_data_path = params["embed_data_path"]
    if embed_data_path is None or not os.path.exists(embed_data_path):
        embed_data_path = output_path

    # Init model
    reranker = BiEncoderRanker(params)
    reranker.model.eval()
    tokenizer = reranker.tokenizer
    n_gpu = reranker.n_gpu

    knn = params["knn"]  # Use as the max-knn value for the graph construction
    use_types = params["use_types"]
    data_split = params["data_split"]  # Default = "test"

    # Load test data
    entity_dictionary_loaded = False
    test_dictionary_pkl_path = os.path.join(pickle_src_path,
                                            'test_dictionary.pickle')
    test_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                             'test_tensor_data.pickle')
    test_mention_data_pkl_path = os.path.join(pickle_src_path,
                                              'test_mention_data.pickle')
    if os.path.isfile(test_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(test_dictionary_pkl_path, 'rb') as read_handle:
            test_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if os.path.isfile(test_tensor_data_pkl_path) and os.path.isfile(
            test_mention_data_pkl_path):
        print("Loading stored processed test data...")
        with open(test_tensor_data_pkl_path, 'rb') as read_handle:
            test_tensor_data = pickle.load(read_handle)
        with open(test_mention_data_pkl_path, 'rb') as read_handle:
            mention_data = pickle.load(read_handle)
    else:
        test_samples = utils.read_dataset(data_split, params["data_path"])
        if not entity_dictionary_loaded:
            with open(os.path.join(params["data_path"], 'dictionary.pickle'),
                      'rb') as read_handle:
                test_dictionary = pickle.load(read_handle)

        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in test_samples[0].keys()
        # Filter out samples without gold entities
        if mult_labels:
            test_samples = [s for s in test_samples if len(s["labels"]) > 0]
        else:
            test_samples = [s for s in test_samples if s["label"] is not None]
        logger.info("Read %d test samples." % len(test_samples))

        mention_data, test_dictionary, test_tensor_data = data_process.process_mention_data(
            test_samples,
            test_dictionary,
            tokenizer,
            params["max_context_length"],
            params["max_cand_length"],
            multi_label_key="labels" if mult_labels else None,
            context_key=params["context_key"],
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            knn=knn,
            dictionary_processed=entity_dictionary_loaded)
        print("Saving processed test data...")
        if not entity_dictionary_loaded:
            with open(test_dictionary_pkl_path, 'wb') as write_handle:
                pickle.dump(test_dictionary,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(test_tensor_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_mention_data_pkl_path, 'wb') as write_handle:
            pickle.dump(mention_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    # Reduce the entity dictionary to only the gold entities of the mention
    # queries, and combine entities and mentions into one structure for joint
    # embedding and indexing
    new_ents = {}
    new_ents_arr = []
    men_labels = []
    for men in mention_data:
        ent = men['label_idxs'][0]
        if ent not in new_ents:
            new_ents[ent] = len(new_ents_arr)
            new_ents_arr.append(ent)
        men_labels.append(new_ents[ent])
    ent_labels = list(range(len(new_ents_arr)))
    new_ent_vecs = torch.tensor(
        list(map(lambda x: test_dictionary[x]['ids'], new_ents_arr)))
    new_ent_types = list(
        map(lambda x: {"type": test_dictionary[x]['type']}, new_ents_arr))
    test_men_vecs = test_tensor_data[:][0]

    n_mentions = len(test_tensor_data)
    n_entities = len(new_ent_vecs)
    n_embeds = n_mentions + n_entities
    leaf_labels = np.array(ent_labels + men_labels, dtype=int)
    all_vecs = torch.cat((new_ent_vecs, test_men_vecs))
    all_types = new_ent_types + mention_data  # Array of dicts containing key "type" for selected ents and all mentions

    # Values of k to run the evaluation against
    if params["exact_knn"] is None:
        knn_vals = [25 * 2**i for i in range(int(math.log(knn / 25, 2)) + 1)]
    else:
        knn_vals = [params["exact_knn"]]
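    # e.g., with knn=200 and exact_knn unset, knn_vals == [25, 50, 100, 200]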
    # Store the maximum evaluation k
    max_knn = knn_vals[-1]

    time_start = time.time()
    knn_fetch_time = 0.  # stays 0 if pre-built graphs are loaded from disk

    # Check if graphs are already built
    graph_path = os.path.join(output_path, 'graphs.pickle')
    if os.path.isfile(graph_path):
        print("Loading stored joint graphs...")
        with open(graph_path, 'rb') as read_handle:
            joint_graphs = pickle.load(read_handle)
    else:
        # Initialize graphs to store mention-mention and mention-entity similarity score edges;
        # Keyed on k, the number of nearest mentions retrieved
        joint_graphs = {}
        for k in knn_vals:
            joint_graphs[k] = {
                'rows': np.array([]),
                'cols': np.array([]),
                'data': np.array([]),
                'shape': (n_embeds, n_embeds)
            }

        # Check and load stored embedding data
        embed_data_path = os.path.join(embed_data_path, 'embed_data.t7')
        embed_data = None
        if os.path.isfile(embed_data_path):
            embed_data = torch.load(embed_data_path)
        if use_types:
            if embed_data is not None:
                logger.info('Loading stored embeddings')
                embeds = embed_data['embeds']
                if 'idxs_by_type' in embed_data:
                    idxs_by_type = embed_data['idxs_by_type']
                else:
                    idxs_by_type = data_process.get_idxs_by_type(all_types)
            else:
                logger.info("Embedding data")
                dict_embeds = data_process.embed_and_index(
                    reranker,
                    all_vecs[:n_entities],
                    encoder_type='candidate',
                    only_embed=True,
                    n_gpu=n_gpu,
                    batch_size=params['embed_batch_size'])
                men_embeds = data_process.embed_and_index(
                    reranker,
                    all_vecs[n_entities:],
                    encoder_type='context',
                    only_embed=True,
                    n_gpu=n_gpu,
                    batch_size=params['embed_batch_size'])
                embeds = np.concatenate((dict_embeds, men_embeds), axis=0)
                idxs_by_type = data_process.get_idxs_by_type(all_types)
        else:
            if embed_data is not None:
                logger.info('Loading stored embeddings')
                embeds = embed_data['embeds']
            else:
                logger.info("Embedding data")
                dict_embeds = data_process.embed_and_index(
                    reranker,
                    all_vecs[:n_entities],
                    encoder_type='candidate',
                    only_embed=True,
                    n_gpu=n_gpu,
                    batch_size=params['embed_batch_size'])
                men_embeds = data_process.embed_and_index(
                    reranker,
                    all_vecs[n_entities:],
                    encoder_type='context',
                    only_embed=True,
                    n_gpu=n_gpu,
                    batch_size=params['embed_batch_size'])
                embeds = np.concatenate((dict_embeds, men_embeds), axis=0)
        # Save computed embedding data if not loaded from disk
        if embed_data is None:
            embed_data = {}
            embed_data['embeds'] = embeds
            if use_types:
                embed_data['idxs_by_type'] = idxs_by_type
            # NOTE: Cannot pickle faiss index because it is a SwigPyObject
            torch.save(embed_data,
                       embed_data_path,
                       pickle_protocol=pickle.HIGHEST_PROTOCOL)

        # Build faiss search index
        if params["normalize_embeds"]:
            embeds = normalize(embeds, axis=1)  # L2-normalize each embedding row
        logger.info("Building KNN index...")
        if use_types:
            search_indexes = data_process.get_index_from_embeds(
                embeds, corpus_idxs=idxs_by_type, force_exact_search=True)
        else:
            search_index = data_process.get_index_from_embeds(
                embeds, force_exact_search=True)

        logger.info("Starting KNN search...")
        if not use_types:
            faiss_dists, faiss_idxs = search_index.search(embeds, max_knn + 1)
        else:
            query_len = n_embeds
            faiss_idxs = np.zeros((query_len, max_knn + 1), dtype=int)
            faiss_dists = np.zeros((query_len, max_knn + 1), dtype=float)
            for entity_type in search_indexes:
                embeds_by_type = embeds[idxs_by_type[entity_type]]
                nn_dists_by_type, nn_idxs_by_type = search_indexes[
                    entity_type].search(embeds_by_type, max_knn + 1)
                for i, idx in enumerate(idxs_by_type[entity_type]):
                    faiss_idxs[idx] = nn_idxs_by_type[i]
                    faiss_dists[idx] = nn_dists_by_type[i]
        logger.info("Search finished")

        logger.info('Building graphs')
        # Find the most similar nodes for each mention and node in the set (minus self)
        for idx in trange(n_embeds):
            # Compute adjacent node edge weight
            if idx != 0:
                adj_idx = idx - 1
                adj_data = embeds[adj_idx] @ embeds[idx]
            nn_idxs = faiss_idxs[idx]
            nn_scores = faiss_dists[idx]
            # Drop the self-match and keep only the top max_knn candidates
            filter_mask = nn_idxs != idx
            nn_idxs = nn_idxs[filter_mask][:max_knn]
            nn_scores = nn_scores[filter_mask][:max_knn]
            # Add edges to the graphs
            for k in joint_graphs:
                joint_graph = joint_graphs[k]
                # Add an edge to the adjacent node to keep the graph connected
                if idx != 0:
                    joint_graph['rows'] = np.append(joint_graph['rows'],
                                                    adj_idx)
                    joint_graph['cols'] = np.append(joint_graph['cols'], idx)
                    joint_graph['data'] = np.append(joint_graph['data'],
                                                    adj_data)
                # Add mention-mention edges
                joint_graph['rows'] = np.append(joint_graph['rows'], [idx] * k)
                joint_graph['cols'] = np.append(joint_graph['cols'],
                                                nn_idxs[:k])
                joint_graph['data'] = np.append(joint_graph['data'],
                                                nn_scores[:k])

        knn_fetch_time = time.time() - time_start
        # Pickle the graphs
        print("Saving joint graphs...")
        with open(graph_path, 'wb') as write_handle:
            pickle.dump(joint_graphs,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

        if params['only_embed_and_build']:
            logger.info(f"Saved embedding data at: {embed_data_path}")
            logger.info(f"Saved graphs at: {graph_path}")
            exit()

    results = {
        'n_leaves': n_embeds,
        'n_entities': n_entities,
        'n_mentions': n_mentions
    }

    graph_processing_time = time.time()
    n_graphs_processed = 0.
    # Different HAC linkage functions to run the analyses over
    linkage_fns = (["single", "complete", "average"]
                   if params["linkage"] is None else [params["linkage"]])

    for fn in linkage_fns:
        logger.info(f"Linkage function: {fn}")
        purities = []
        fn_result = {}
        for k in joint_graphs:
            graph = hg.UndirectedGraph(n_embeds)
            graph.add_edges(joint_graphs[k]['rows'], joint_graphs[k]['cols'])
            # Negate scores: Higra expects edge weights as distances, not similarities
            weights = -joint_graphs[k]['data']
            tree = get_hac_tree(graph, weights, linkage=fn)
            purity = hg.dendrogram_purity(tree, leaf_labels)
            fn_result[f"purity@{k}nn"] = purity
            logger.info(f"purity@{k}nn = {purity}")
            purities.append(purity)
            n_graphs_processed += 1
        fn_result["average"] = round(np.mean(purities), 4)
        logger.info(f"average = {fn_result['average']}")
        results[fn] = fn_result

    avg_graph_processing_time = (time.time() -
                                 graph_processing_time) / n_graphs_processed
    avg_per_graph_time = (knn_fetch_time + avg_graph_processing_time) / 60
    execution_time = (time.time() - time_start) / 60

    # Store results
    output_file_name = os.path.join(output_path,
                                    f"results_{int(time.time())}")

    logger.info(f"Results: \n {results}")
    with open(f'{output_file_name}.json', 'w') as f:
        json.dump(results, f, indent=2)
        print(f"\nResults saved at: {output_file_name}.json")

    logger.info("\nThe avg. per graph evaluation time is {} minutes\n".format(
        avg_per_graph_time))
    logger.info(
        "\nThe total evaluation took {} minutes\n".format(execution_time))
Example #17
def main(params):
    time_start = time.time()

    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"], 'log-discovery')

    embed_data_path = params["embed_data_path"]
    if embed_data_path is None or not os.path.exists(embed_data_path):
        embed_data_path = output_path

    graph_path = params["graph_path"]
    if graph_path is None or not os.path.exists(graph_path):
        graph_path = output_path

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path

    rng = np.random.default_rng(seed=17)
    knn = params["knn"]
    use_types = params["use_types"]
    data_split = params["data_split"] # Default = "test"
    graph_mode = params.get('graph_mode', None)

    logger.info(f"Dataset: {data_split.upper()}")

    # Load evaluation data
    entity_dictionary_loaded = False
    dictionary_pkl_path = os.path.join(pickle_src_path, 'test_dictionary.pickle')
    tensor_data_pkl_path = os.path.join(pickle_src_path, 'test_tensor_data.pickle')
    mention_data_pkl_path = os.path.join(pickle_src_path, 'test_mention_data.pickle')
    print("Loading stored processed entity dictionary...")
    with open(dictionary_pkl_path, 'rb') as read_handle:
        dictionary = pickle.load(read_handle)
    print("Loading stored processed mention data...")
    with open(tensor_data_pkl_path, 'rb') as read_handle:
        tensor_data = pickle.load(read_handle)
    with open(mention_data_pkl_path, 'rb') as read_handle:
        mention_data = pickle.load(read_handle)

    # Load stored joint graphs
    graph_path = os.path.join(graph_path, 'graphs.pickle')
    print("Loading stored joint graphs...")
    with open(graph_path, 'rb') as read_handle:
        joint_graphs = pickle.load(read_handle)

    if not params['drop_all_entities']:
        # Embedding data is only needed when some entities are kept
        print("Loading stored embedding data...")
        embed_data_path = os.path.join(embed_data_path, 'embed_data.t7')
        embed_data = torch.load(embed_data_path)

    n_entities = len(dictionary)
    seen_mention_idxs = set()
    unseen_mention_idxs_map = {}
    if params["seen_data_path"] is not None:  # Plug data leakage
        with open(params["seen_data_path"], 'rb') as read_handle:
            seen_data = pickle.load(read_handle)
        seen_cui_idxs = set()
        for seen_men in seen_data:
            seen_cui_idxs.add(seen_men['label_idxs'][0])
        logger.info(f"CUIs seen at training: {len(seen_cui_idxs)}")
        filtered_mention_data = []
        for menidx, men in enumerate(mention_data):
            if men['label_idxs'][0] not in seen_cui_idxs:
                filtered_mention_data.append(men)
                unseen_mention_idxs_map[menidx] = len(filtered_mention_data) - 1
            else:
                seen_mention_idxs.add(menidx)
        if not params['no_drop_seen']:
            logger.info("Dropping mentions whose CUIs were seen during training")
            logger.info(f"Unfiltered mention size: {len(mention_data)}")
            mention_data = filtered_mention_data
            logger.info(f"Filtered mention size: {len(mention_data)}")
    n_mentions = len(mention_data)
    n_labels = 1  # Zeshel and MedMentions have single gold entity mentions

    mention_gold_cui_idxs = list(map(lambda x: x['label_idxs'][n_labels - 1], mention_data))
    ents_in_data = np.unique(mention_gold_cui_idxs)

    if params['drop_all_entities']:
        ent_drop_prop = 1
        n_ents_dropped = len(ents_in_data)
        n_mentions_wo_gold_ents = n_mentions
        logger.info(f"Dropping all {n_ents_dropped} entities found in mention set")
        set_dropped_ent_idxs = set()
    else:
        # Percentage of entities from the mention set to drop
        ent_drop_prop = 0.1
        
        logger.info(f"Dropping {ent_drop_prop*100}% of {len(ents_in_data)} entities found in mention set")

        # Get entity indices to drop
        n_ents_dropped = int(ent_drop_prop*len(ents_in_data))
        dropped_ent_idxs = rng.choice(ents_in_data, size=n_ents_dropped, replace=False)
        set_dropped_ent_idxs = set(dropped_ent_idxs)
        
        n_mentions_wo_gold_ents = sum([1 if x in set_dropped_ent_idxs else 0 for x in mention_gold_cui_idxs])
        logger.info(f"Dropped {n_ents_dropped} entities")
        logger.info(f"=> Mentions without gold entities = {n_mentions_wo_gold_ents}")

        # Load embeddings in order to compute new KNN entities after dropping
        print('Computing new dictionary indexes...')
        
        original_dict_embeds = embed_data['dict_embeds']
        keep_mask = np.ones(len(original_dict_embeds), dtype='bool')
        keep_mask[dropped_ent_idxs] = False
        dict_embeds = original_dict_embeds[keep_mask]

        new_to_old_dict_mapping = []
        for i in range(len(original_dict_embeds)):
            if keep_mask[i]:
                new_to_old_dict_mapping.append(i)

        men_embeds = embed_data['men_embeds']
        if use_types:
            dict_idxs_by_type = data_process.get_idxs_by_type(list(compress(dictionary, keep_mask)))
            dict_indexes = data_process.get_index_from_embeds(dict_embeds, dict_idxs_by_type, force_exact_search=params['force_exact_search'], probe_mult_factor=params['probe_mult_factor'])
            if 'men_idxs_by_type' in embed_data:
                men_idxs_by_type = embed_data['men_idxs_by_type']
            else:
                men_idxs_by_type = data_process.get_idxs_by_type(mention_data)
        else:
            dict_index = data_process.get_index_from_embeds(dict_embeds, force_exact_search=params['force_exact_search'], probe_mult_factor=params['probe_mult_factor'])
        
        # Fetch additional KNN entity to make sure every mention has a linked entity after dropping
        extra_entity_knn = []
        if use_types:
            for men_type in men_idxs_by_type:
                dict_index = dict_indexes[men_type]
                dict_type_idx_mapping = dict_idxs_by_type[men_type]
                q_men_embeds = men_embeds[men_idxs_by_type[men_type]]
                fetch_k = 1 if isinstance(dict_index, faiss.IndexFlatIP) else 16
                _, nn_idxs = dict_index.search(q_men_embeds, fetch_k)
                for i, men_idx in enumerate(men_idxs_by_type[men_type]):
                    r = n_entities + men_idx
                    q_nn_idxs = dict_type_idx_mapping[nn_idxs[i]]
                    q_nn_embeds = torch.tensor(dict_embeds[q_nn_idxs]).cuda()
                    q_scores = torch.flatten(
                        torch.mm(torch.tensor(q_men_embeds[i:i+1]).cuda(), q_nn_embeds.T)).cpu()
                    c, data = new_to_old_dict_mapping[q_nn_idxs[torch.argmax(q_scores)]], torch.max(q_scores)
                    extra_entity_knn.append((r,c,data))
        else:
            fetch_k = 1 if isinstance(dict_index, faiss.IndexFlatIP) else 16
            _, nn_idxs = dict_index.search(men_embeds, fetch_k)
            for men_idx, men_embed in enumerate(men_embeds):
                r = n_entities + men_idx
                q_nn_idxs = nn_idxs[men_idx]
                q_nn_embeds = torch.tensor(dict_embeds[q_nn_idxs]).cuda()
                q_scores = torch.flatten(
                    torch.mm(torch.tensor(np.expand_dims(men_embed, axis=0)).cuda(), q_nn_embeds.T)).cpu()
                c, data = new_to_old_dict_mapping[q_nn_idxs[torch.argmax(q_scores)]], torch.max(q_scores)
                extra_entity_knn.append((r,c,data))
        
        # Add an edge from each mention to its nearest surviving entity so that
        # no mention is left without a linked entity after the drop
        rows, cols, data = [], [], []
        for edge in extra_entity_knn:
            rows.append(edge[0])
            cols.append(edge[1])
            data.append(edge[2])
        for k in joint_graphs:
            joint_graphs[k]['rows'] = np.concatenate((joint_graphs[k]['rows'], rows))
            joint_graphs[k]['cols'] = np.concatenate((joint_graphs[k]['cols'], cols))
            joint_graphs[k]['data'] = np.concatenate((joint_graphs[k]['data'], data))

    results = {
        'data_split': data_split.upper(),
        'n_entities': n_entities,
        'n_mentions': n_mentions,
        'n_entities_dropped': f"{n_ents_dropped} ({ent_drop_prop*100}%)",
        'n_mentions_wo_gold_entities': n_mentions_wo_gold_ents
    }
    if graph_mode is None or graph_mode not in ['directed', 'undirected']:
        graph_mode = ['directed', 'undirected']
    else:
        graph_mode = [graph_mode]

    n_thresholds = params['n_thresholds'] # Default is 10
    exact_threshold = params.get('exact_threshold', None)
    exact_knn = params.get('exact_knn', None)
    
    kmeans = KMeans(n_clusters=n_thresholds, random_state=17)

    # TODO: Baseline? (without dropping entities)

    for mode in graph_mode:
        best_result = -1.
        best_config = None
        for k in joint_graphs:
            if params['drop_all_entities']:
                # Drop all entities from the graph
                rows, cols, data = joint_graphs[k]['rows'], joint_graphs[k]['cols'], joint_graphs[k]['data']
                _f_row, _f_col, _f_data = [], [], []
                for ki in range(len(joint_graphs[k]['rows'])):
                    if joint_graphs[k]['cols'][ki] < n_entities or joint_graphs[k]['rows'][ki] < n_entities:
                        continue
                    # Remove mentions whose gold entity was seen during training
                    if len(seen_mention_idxs) > 0 and not params['no_drop_seen']:
                        if (joint_graphs[k]['rows'][ki] - n_entities) in seen_mention_idxs or \
                                (joint_graphs[k]['cols'][ki] - n_entities) in seen_mention_idxs:
                            continue
                    _f_row.append(joint_graphs[k]['rows'][ki])
                    _f_col.append(joint_graphs[k]['cols'][ki])
                    _f_data.append(joint_graphs[k]['data'][ki])
                joint_graphs[k]['rows'], joint_graphs[k]['cols'], joint_graphs[k]['data'] = list(map(np.array, (_f_row, _f_col, _f_data)))
            if (exact_knn is None and k > 0 and k <= knn) or (exact_knn is not None and k == exact_knn):
                if exact_threshold is not None:
                    thresholds = np.array([0, exact_threshold])
                else:
                    edge_weights = joint_graphs[k]['data'].reshape(-1, 1)
                    centers = kmeans.fit(edge_weights).cluster_centers_.flatten()
                    thresholds = np.sort(np.concatenate(([0], centers)))
                for thresh in thresholds:
                    print("\nPartitioning...")
                    logger.info(f"{mode.upper()}, k={k}, threshold={thresh}")
                    # Partition graph based on cluster-linking constraints
                    partitioned_graph, clusters = partition_graph(
                        joint_graphs[k], n_entities, mode == 'directed', return_clusters=True, exclude=set_dropped_ent_idxs, threshold=thresh, without_entities=params['drop_all_entities'])
                    # Analyze cluster against gold clusters
                    result = analyzeClusters(clusters, mention_gold_cui_idxs, n_entities, n_mentions, logger, unseen_mention_idxs_map, no_drop_seen=params['no_drop_seen'])
                    results[f'({mode}, {k}, {thresh})'] = result
                    if not params['no_drop_seen']:
                        if thresh != 0 and result['average'] > best_result:
                            best_result = result['average']
                            best_config = (mode, k, thresh)
        if not params['no_drop_seen']:
            results[f'best_{mode}_config'] = best_config
            results[f'best_{mode}_result'] = best_result
    
    # Store results
    output_file_name = os.path.join(
        output_path, f"{data_split}_eval_discovery_{int(time.time())}.json")

    with open(output_file_name, 'w') as f:
        json.dump(results, f, indent=2)
        print(f"\nAnalysis saved at: {output_file_name}")
    execution_time = (time.time() - time_start) / 60
    logger.info(f"\nTotal time taken: {execution_time} minutes\n")
Example #18
def main(params):
    # Parameter initializations
    logger = utils.get_logger(params["output_path"])
    global SCORING_BATCH_SIZE
    SCORING_BATCH_SIZE = params["scoring_batch_size"]
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path
    biencoder_indices_path = params["biencoder_indices_path"]
    if biencoder_indices_path is None:
        biencoder_indices_path = output_path
    elif not os.path.exists(biencoder_indices_path):
        os.makedirs(biencoder_indices_path)
    max_k = params["knn"]  # Maximum k-NN graph to build for evaluation
    use_types = params["use_types"]
    within_doc = params["within_doc"]
    discovery_mode = params["discovery"]

    # Bi-encoder model
    biencoder_params = copy.deepcopy(params)
    biencoder_params['add_linear'] = False
    bi_reranker = BiEncoderRanker(biencoder_params)
    bi_tokenizer = bi_reranker.tokenizer
    # Number of biencoder nearest neighbors to fetch for cross-encoder scoring (default: 64)
    k_biencoder = params["bi_knn"]

    # Cross-encoder model
    params['add_linear'] = True
    params['add_sigmoid'] = True
    cross_reranker = CrossEncoderRanker(params)
    n_gpu = cross_reranker.n_gpu
    cross_reranker.model.eval()

    # Input lengths
    max_seq_length = params["max_seq_length"]
    max_context_length = params["max_context_length"]
    max_cand_length = params["max_cand_length"]

    # Fix random seeds
    seed = params["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cross_reranker.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    # Generate and save the top-k biencoder candidates for cross-encoder
    # training and inference, then exit
    if params["save_topk_result"]:
        save_topk_biencoder_cands(bi_reranker,
                                  use_types,
                                  logger,
                                  n_gpu,
                                  params,
                                  bi_tokenizer,
                                  max_context_length,
                                  max_cand_length,
                                  pickle_src_path,
                                  topk=64)
        exit()

    data_split = params["data_split"]
    entity_dictionary, tensor_data, processed_data = load_data(
        data_split, bi_tokenizer, max_context_length, max_cand_length, max_k,
        pickle_src_path, params, logger)
    n_entities = len(entity_dictionary)
    n_mentions = len(processed_data)
    # Store dictionary vectors
    dict_vecs = torch.tensor(list(map(lambda x: x['ids'], entity_dictionary)),
                             dtype=torch.long)
    # Store query vectors
    men_vecs = tensor_data[:][0]

    discovery_entities = []
    if discovery_mode:
        discovery_entities, n_ents_dropped, n_mentions_wo_gold_ents, mention_gold_entities = get_entity_idxs_to_drop(
            processed_data, params, logger)

    context_doc_ids = None
    if within_doc:
        # Get context_document_ids for each mention in training and validation
        context_doc_ids = get_context_doc_ids(data_split, params)
    # Needed so that get_biencoder_nns() runs in evaluation-only mode
    params["only_evaluate"] = True
    _, biencoder_nns = get_biencoder_nns(
        bi_reranker=bi_reranker,
        biencoder_indices_path=biencoder_indices_path,
        entity_dictionary=entity_dictionary,
        entity_dict_vecs=dict_vecs,
        train_men_vecs=None,
        train_processed_data=None,
        train_gold_clusters=None,
        valid_men_vecs=men_vecs,
        valid_processed_data=processed_data,
        use_types=use_types,
        logger=logger,
        n_gpu=n_gpu,
        params=params,
        train_context_doc_ids=None,
        valid_context_doc_ids=context_doc_ids)
    bi_men_idxs = biencoder_nns['men_nns'][:, :k_biencoder]
    bi_ent_idxs = biencoder_nns['dict_nns'][:, :k_biencoder]
    bi_nn_count = np.sum(biencoder_nns['men_nns'] != -1, axis=1)

    # Compute and store the concatenated cross-encoder inputs for validation
    men_concat_inputs, ent_concat_inputs = build_cross_concat_input(
        biencoder_nns, men_vecs, dict_vecs, max_seq_length, k_biencoder)

    # Values of k to run the evaluation against
    knn_vals = [0] + [2**i for i in range(int(math.log(max_k, 2)) + 1)]
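    # e.g., with max_k=64, knn_vals == [0, 1, 2, 4, 8, 16, 32, 64]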
    max_k = knn_vals[-1]  # Store the maximum evaluation k
    bi_recall = None

    time_start = time.time()
    # Check if k-NN graphs are already built
    graph_path = os.path.join(output_path, 'graphs.pickle')
    if not params['only_recall'] and os.path.isfile(graph_path):
        print("Loading stored joint graphs...")
        with open(graph_path, 'rb') as read_handle:
            joint_graphs = pickle.load(read_handle)
    else:
        joint_graphs = {}
        for k in knn_vals:
            joint_graphs[k] = {
                'rows': np.array([]),
                'cols': np.array([]),
                'data': np.array([]),
                'shape': (n_entities + n_mentions, n_entities + n_mentions)
            }

        # Score biencoder NNs using cross-encoder
        score_path = os.path.join(output_path, 'cross_scores_indexes.pickle')
        if os.path.isfile(score_path):
            print("Loading stored cross-encoder scores and indexes...")
            with open(score_path, 'rb') as read_handle:
                score_data = pickle.load(read_handle)
            cross_men_topk_idxs = score_data['cross_men_topk_idxs']
            cross_men_topk_scores = score_data['cross_men_topk_scores']
            cross_ent_top1_idx = score_data['cross_ent_top1_idx']
            cross_ent_top1_score = score_data['cross_ent_top1_score']
        else:
            with torch.no_grad():
                logger.info(
                    'Eval: Scoring mention-mention edges using cross-encoder...'
                )
                cross_men_scores = score_in_batches(
                    cross_reranker,
                    max_context_length,
                    men_concat_inputs,
                    is_context_encoder=True,
                    scoring_batch_size=SCORING_BATCH_SIZE)
                for i in range(len(cross_men_scores)):
                    # Set scores for all invalid nearest neighbours to -infinity (due to variable NN counts of mentions)
                    cross_men_scores[i][bi_nn_count[i]:] = float('-inf')
                cross_men_topk_scores, cross_men_topk_idxs = torch.sort(
                    cross_men_scores, dim=1, descending=True)
                cross_men_topk_idxs = cross_men_topk_idxs.cpu()[:, :max_k]
                cross_men_topk_scores = cross_men_topk_scores.cpu()[:, :max_k]
                logger.info('Eval: Scoring done')

                logger.info(
                    'Eval: Scoring mention-entity edges using cross-encoder...'
                )
                cross_ent_scores = score_in_batches(
                    cross_reranker,
                    max_context_length,
                    ent_concat_inputs,
                    is_context_encoder=False,
                    scoring_batch_size=SCORING_BATCH_SIZE)
                cross_ent_top1_score, cross_ent_top1_idx = torch.sort(
                    cross_ent_scores, dim=1, descending=True)
                cross_ent_top1_idx = cross_ent_top1_idx.cpu()
                cross_ent_top1_score = cross_ent_top1_score.cpu()
                if discovery_mode:
                    # Replace the first value in each row with the highest-ranked
                    # entity that is not in the drop set
                    for i in range(cross_ent_top1_idx.shape[0]):
                        for j in range(cross_ent_top1_idx.shape[1]):
                            if cross_ent_top1_idx[i, j] not in discovery_entities:
                                cross_ent_top1_idx[i, 0] = cross_ent_top1_idx[i, j]
                                cross_ent_top1_score[i, 0] = cross_ent_top1_score[i, j]
                                break
                logger.info('Eval: Scoring done')
            # Pickle the scores and nearest indexes
            logger.info("Saving cross-encoder scores and indexes...")
            with open(score_path, 'wb') as write_handle:
                pickle.dump(
                    {
                        'cross_men_topk_idxs': cross_men_topk_idxs,
                        'cross_men_topk_scores': cross_men_topk_scores,
                        'cross_ent_top1_idx': cross_ent_top1_idx,
                        'cross_ent_top1_score': cross_ent_top1_score
                    },
                    write_handle,
                    protocol=pickle.HIGHEST_PROTOCOL)
            logger.info(f"Saved at: {score_path}")

        # Build k-NN graphs
        bi_recall = 0.
        for men_idx in tqdm(range(len(processed_data)),
                            total=len(processed_data),
                            desc="Eval: Building graphs"):
            # Track biencoder recall@<k_biencoder>
            gold_idx = processed_data[men_idx]["label_idxs"][0]
            if gold_idx in bi_ent_idxs[men_idx]:
                bi_recall += 1.
            # Get nearest entity
            m_e_idx = bi_ent_idxs[men_idx, cross_ent_top1_idx[men_idx]]
            m_e_score = cross_ent_top1_score[men_idx]
            if bi_nn_count[men_idx] > 0:
                # Get nearest mentions
                topk_defined_nn_idxs = cross_men_topk_idxs[
                    men_idx][:bi_nn_count[men_idx]]
                # Mentions are indexed at an offset of n_entities
                m_m_idxs = bi_men_idxs[men_idx, topk_defined_nn_idxs] + n_entities
                m_m_scores = cross_men_topk_scores[
                    men_idx][:bi_nn_count[men_idx]]
            # Add edges to the graphs
            for k in joint_graphs:
                # Add mention-entity edge
                joint_graphs[k]['rows'] = np.append(
                    joint_graphs[k]['rows'],
                    [n_entities + men_idx
                     ])  # Mentions added at an offset of maximum entities
                joint_graphs[k]['cols'] = np.append(joint_graphs[k]['cols'],
                                                    m_e_idx)
                joint_graphs[k]['data'] = np.append(joint_graphs[k]['data'],
                                                    m_e_score)
                if k > 0 and bi_nn_count[men_idx] > 0:
                    # Add mention-mention edges
                    joint_graphs[k]['rows'] = np.append(
                        joint_graphs[k]['rows'],
                        [n_entities + men_idx] * len(m_m_idxs[:k]))
                    joint_graphs[k]['cols'] = np.append(
                        joint_graphs[k]['cols'], m_m_idxs[:k])
                    joint_graphs[k]['data'] = np.append(
                        joint_graphs[k]['data'], m_m_scores[:k])
        # Pickle the graphs
        print("Saving joint graphs...")
        with open(graph_path, 'wb') as write_handle:
            pickle.dump(joint_graphs,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        logger.info(f"Saved graphs at: {graph_path}")
        # Compute biencoder recall
        bi_recall /= len(processed_data)
        if params['only_recall']:
            logger.info(
                f"Eval: Biencoder recall@{k_biencoder} = {bi_recall * 100}%")
            exit()

    graph_mode = params.get('graph_mode', None)

    if discovery_mode:
        # Run the entity discovery experiment
        results = {
            'data_split': data_split.upper(),
            'n_entities': n_entities,
            'n_mentions': n_mentions,
            'n_entities_dropped':
            f"{n_ents_dropped} ({params['ent_drop_prop'] * 100}%)",
            'n_mentions_wo_gold_entities': n_mentions_wo_gold_ents
        }
        run_discovery_experiment(joint_graphs, discovery_entities,
                                 mention_gold_entities, n_entities, n_mentions,
                                 data_split, results, output_path, time_start,
                                 params, logger)
    else:
        # Run entity linking inference
        result_overview, results = {
            'n_entities': n_entities,
            'n_mentions': n_mentions
        }, {}
        if graph_mode is None or graph_mode not in ['directed', 'undirected']:
            results['directed'], results['undirected'] = [], []
        else:
            results[graph_mode] = []
        run_inference(entity_dictionary, processed_data, results,
                      result_overview, joint_graphs, n_entities, time_start,
                      output_path, bi_recall, k_biencoder, logger)
Example #19
File: main_dense.py  Project: yyht/BLINK
    parser.add_argument(
        "--top_k",
        dest="top_k",
        type=int,
        default=10,
        help="Number of candidates retrieved by biencoder.",
    )

    # output folder
    parser.add_argument(
        "--output_path",
        dest="output_path",
        type=str,
        default="output",
        help="Path to the output.",
    )

    parser.add_argument(
        "--fast", dest="fast", action="store_true", help="only biencoder mode"
    )
    
    parser.add_argument(
        "--show_url", dest="show_url", action="store_true", 
        help="whether to show entity url in interactive mode"
    )


    args = parser.parse_args()

    logger = utils.get_logger(args.output_path)

    models = load_models(args, logger)
    run(args, logger, *models)
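
Assuming the standard BLINK repository layout, the script above would be invoked along these lines (--top_k is the flag defined at the start of this snippet):

python blink/main_dense.py --fast --top_k 10 --output_path output --show_url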
Example #20
def main(params):

    # create output dir
    eval_output_path = os.path.join(params["output_path"], "eval_output")
    if not os.path.exists(eval_output_path):
        os.makedirs(eval_output_path)
    # get logger
    logger = utils.get_logger(eval_output_path)

    # log the command that was run (copy to avoid mutating sys.argv)
    cmd = ["python"] + sys.argv
    logger.info(" ".join(cmd))

    params["pool_highlighted"] = False
    params["path_to_model"] = params["path_to_cand_model"]
    cand_reranker = CrossEncoderRanker(params)
    cand_model = cand_reranker.model

    params["path_to_model"] = None
    tokenizer = cand_reranker.tokenizer

    device = cand_reranker.device
    n_gpu = cand_reranker.n_gpu
    context_length = params["max_context_length"]

    # fix the random seeds
    seed = params["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cand_reranker.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    # create eval dataloaders
    fname = os.path.join(params["data_path"],
                         "joint_" + params["mode"] + ".t7")
    eval_data = torch.load(fname)
    cand_dataloader = create_eval_dataloader(params, eval_data["contexts"],
                                             eval_data["context_uids"],
                                             eval_data["knn_cands"],
                                             eval_data["knn_cand_uids"])

    # construct ground truth data
    gold_linking_map, gold_coref_clusters = build_ground_truth(eval_data)

    # get uids we trained on
    train_data_fname = os.path.join(params["data_path"], "joint_train.t7")
    train_data = torch.load(train_data_fname)
    seen_uids = get_seen_uids(train_data, eval_data)

    # score all of the mention-candidate edges
    cand_edges = score_contexts(
        cand_reranker,
        cand_dataloader,
        device=device,
        logger=logger,
        context_length=context_length,
        suffix="cand",
        silent=params["silent"],
    )

    # construct the sparse graphs over all uids
    n_uids = max(gold_linking_map.keys()) + 1
    sparse_shape = (n_uids, n_uids)

    _cand_data = cand_edges[:, 2].cpu().numpy()
    _cand_row = cand_edges[:, 1].cpu().numpy()
    _cand_col = cand_edges[:, 0].cpu().numpy()
    cand_graph = coo_matrix((_cand_data, (_cand_row, _cand_col)),
                            shape=sparse_shape)

    logger.info('Computing linking metrics...')
    linking_metrics, slim_linking_graph = compute_linking_metrics(
        cand_graph, gold_linking_map, seen_uids=seen_uids)
    logger.info('Done.')

    uid_to_json = eval_data['uid_to_json']
    _cand_row = cand_graph.row
    _cand_col = cand_graph.col
    _cand_data = cand_graph.data
    _gt_row, _gt_col, _gt_data = [], [], []
    for r, c, d in zip(_cand_row, _cand_col, _cand_data):
        if uid_to_json[r]['type'] == uid_to_json[c]['type']:
            _gt_row.append(r)
            _gt_col.append(c)
            _gt_data.append(d)

    gold_type_cand_graph = coo_matrix((_gt_data, (_gt_row, _gt_col)),
                                      shape=sparse_shape)

    logger.info('Computing gold-type linking metrics...')
    gold_type_linking_metrics, slim_linking_graph = compute_linking_metrics(
        gold_type_cand_graph, gold_linking_map, seen_uids=seen_uids)
    logger.info('Done.')

    metrics = {
        'vanilla_recall': linking_metrics['vanilla_recall'],
        'vanilla_accuracy': linking_metrics['vanilla_accuracy'],
        'gold_type_vanilla_accuracy':
        gold_type_linking_metrics['vanilla_accuracy'],
        'seen_accuracy': linking_metrics['seen_accuracy'],
        'unseen_accuracy': linking_metrics['unseen_accuracy'],
        'gold_type_seen_accuracy': gold_type_linking_metrics['seen_accuracy'],
        'gold_type_unseen_accuracy':
        gold_type_linking_metrics['unseen_accuracy'],
    }

    logger.info('joint_metrics: {}'.format(
        json.dumps(metrics, sort_keys=True, indent=4)))

    # save all of the predictions for later analysis
    save_data = {}
    save_data.update(linking_metrics)
    gold_type_linking_metrics = {
        'gold_type_' + k: v
        for k, v in gold_type_linking_metrics.items()
    }
    save_data.update(gold_type_linking_metrics)

    save_fname = os.path.join(eval_output_path, 'results.t7')
    torch.save(save_data, save_fname)
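
The linking graph above is a plain SciPy COO matrix over context/candidate uids; a minimal reconstruction of that construction with toy edges:

import numpy as np
from scipy.sparse import coo_matrix

rows = np.array([0, 1, 1])         # mention (context) uids
cols = np.array([3, 4, 3])         # candidate uids
data = np.array([0.9, 0.7, 0.2])   # cross-encoder scores
graph = coo_matrix((data, (rows, cols)), shape=(5, 5))
print(graph.toarray())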