Example #1
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"], 'log-eval')

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path

    embed_data_path = params["embed_data_path"]
    if embed_data_path is None or not os.path.exists(embed_data_path):
        embed_data_path = output_path

    cands_path = os.path.join(output_path, 'cands.pickle')

    # Init model
    reranker = BiEncoderRanker(params)
    reranker.model.eval()
    tokenizer = reranker.tokenizer
    n_gpu = reranker.n_gpu

    knn = params["knn"]
    use_types = params["use_types"]
    within_doc = params["within_doc"]
    data_split = params["data_split"]  # Default = "test"

    # Load test data
    test_samples = None
    entity_dictionary_loaded = False
    test_dictionary_pkl_path = os.path.join(pickle_src_path,
                                            'test_dictionary.pickle')
    test_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                             'test_tensor_data.pickle')
    test_mention_data_pkl_path = os.path.join(pickle_src_path,
                                              'test_mention_data.pickle')
    if params['transductive']:
        train_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                                  'train_tensor_data.pickle')
        train_mention_data_pkl_path = os.path.join(
            pickle_src_path, 'train_mention_data.pickle')
    if os.path.isfile(test_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(test_dictionary_pkl_path, 'rb') as read_handle:
            test_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if os.path.isfile(test_tensor_data_pkl_path) and os.path.isfile(
            test_mention_data_pkl_path):
        print("Loading stored processed test data...")
        with open(test_tensor_data_pkl_path, 'rb') as read_handle:
            test_tensor_data = pickle.load(read_handle)
        with open(test_mention_data_pkl_path, 'rb') as read_handle:
            mention_data = pickle.load(read_handle)
    else:
        test_samples = utils.read_dataset(data_split, params["data_path"])
        if not entity_dictionary_loaded:
            with open(os.path.join(params["data_path"], 'dictionary.pickle'),
                      'rb') as read_handle:
                test_dictionary = pickle.load(read_handle)

        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in test_samples[0].keys()
        if params["filter_unlabeled"]:
            # Filter samples without gold entities
            test_samples = list(
                filter(
                    lambda sample: (len(sample["labels"]) > 0)
                    if mult_labels else (sample["label"] is not None),
                    test_samples))
        logger.info("Read %d test samples." % len(test_samples))

        mention_data, test_dictionary, test_tensor_data = data_process.process_mention_data(
            test_samples,
            test_dictionary,
            tokenizer,
            params["max_context_length"],
            params["max_cand_length"],
            multi_label_key="labels" if mult_labels else None,
            context_key=params["context_key"],
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            knn=knn,
            dictionary_processed=entity_dictionary_loaded)
        print("Saving processed test data...")
        if not entity_dictionary_loaded:
            with open(test_dictionary_pkl_path, 'wb') as write_handle:
                pickle.dump(test_dictionary,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            entity_dictionary_loaded = True
        with open(test_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(test_tensor_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_mention_data_pkl_path, 'wb') as write_handle:
            pickle.dump(mention_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    # Store test dictionary token ids
    test_dict_vecs = torch.tensor(list(map(lambda x: x['ids'],
                                           test_dictionary)),
                                  dtype=torch.long)
    # Store test mention token ids
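    # (Assuming test_tensor_data is a torch TensorDataset: slicing with [:] yields a
    # tuple of stacked tensors, and index 0 holds the context token ids)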
    test_men_vecs = test_tensor_data[:][0]

    n_entities = len(test_dict_vecs)
    n_mentions = len(test_tensor_data)

    if within_doc:
        if test_samples is None:
            test_samples, _ = read_data(data_split, params, logger)
        test_context_doc_ids = [s['context_doc_id'] for s in test_samples]

    if params["transductive"]:
        if os.path.isfile(train_tensor_data_pkl_path) and os.path.isfile(
                train_mention_data_pkl_path):
            print("Loading stored processed train data...")
            with open(train_tensor_data_pkl_path, 'rb') as read_handle:
                train_tensor_data = pickle.load(read_handle)
            with open(train_mention_data_pkl_path, 'rb') as read_handle:
                train_mention_data = pickle.load(read_handle)
        else:
            train_samples = utils.read_dataset('train', params["data_path"])

            # Check if dataset has multiple ground-truth labels
            mult_labels = "labels" in train_samples[0].keys()
            logger.info("Read %d test samples." % len(test_samples))

            train_mention_data, _, train_tensor_data = data_process.process_mention_data(
                train_samples,
                test_dictionary,
                tokenizer,
                params["max_context_length"],
                params["max_cand_length"],
                multi_label_key="labels" if mult_labels else None,
                context_key=params["context_key"],
                silent=params["silent"],
                logger=logger,
                debug=params["debug"],
                knn=knn,
                dictionary_processed=entity_dictionary_loaded)
            print("Saving processed train data...")
            with open(train_tensor_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_tensor_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            with open(train_mention_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_mention_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

        # Store train mention token ids
        train_men_vecs = train_tensor_data[:][0]
        n_mentions += len(train_tensor_data)
        n_train_mentions = len(train_tensor_data)

    # if os.path.isfile(cands_path):
    #     print("Loading stored candidates...")
    #     with open(cands_path, 'rb') as read_handle:
    #         men_cands = pickle.load(read_handle)
    # else:
    # Check and load stored embedding data
    embed_data_path = os.path.join(embed_data_path, 'embed_data.t7')
    embed_data = None
    if os.path.isfile(embed_data_path):
        embed_data = torch.load(embed_data_path)

    if use_types:
        if embed_data is not None:
            logger.info('Loading stored embeddings and computing indexes')
            dict_embeds = embed_data['dict_embeds']
            if 'dict_idxs_by_type' in embed_data:
                dict_idxs_by_type = embed_data['dict_idxs_by_type']
            else:
                dict_idxs_by_type = data_process.get_idxs_by_type(
                    test_dictionary)
            dict_indexes = data_process.get_index_from_embeds(
                dict_embeds,
                dict_idxs_by_type,
                force_exact_search=params['force_exact_search'],
                probe_mult_factor=params['probe_mult_factor'])
            men_embeds = embed_data['men_embeds']
            if 'men_idxs_by_type' in embed_data:
                men_idxs_by_type = embed_data['men_idxs_by_type']
            else:
                men_idxs_by_type = data_process.get_idxs_by_type(mention_data)
            men_indexes = data_process.get_index_from_embeds(
                men_embeds,
                men_idxs_by_type,
                force_exact_search=params['force_exact_search'],
                probe_mult_factor=params['probe_mult_factor'])
        else:
            logger.info("Dictionary: Embedding and building index")
            dict_embeds, dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
                reranker,
                test_dict_vecs,
                encoder_type="candidate",
                n_gpu=n_gpu,
                corpus=test_dictionary,
                force_exact_search=params['force_exact_search'],
                batch_size=params['embed_batch_size'],
                probe_mult_factor=params['probe_mult_factor'])
            logger.info("Queries: Embedding and building index")
            vecs = test_men_vecs
            men_data = mention_data
            if params['transductive']:
                vecs = torch.cat((train_men_vecs, vecs), dim=0)
                men_data = train_mention_data + mention_data
            men_embeds, men_indexes, men_idxs_by_type = data_process.embed_and_index(
                reranker,
                vecs,
                encoder_type="context",
                n_gpu=n_gpu,
                corpus=men_data,
                force_exact_search=params['force_exact_search'],
                batch_size=params['embed_batch_size'],
                probe_mult_factor=params['probe_mult_factor'])
    else:
        if embed_data is not None:
            logger.info('Loading stored embeddings and computing indexes')
            dict_embeds = embed_data['dict_embeds']
            dict_index = data_process.get_index_from_embeds(
                dict_embeds,
                force_exact_search=params['force_exact_search'],
                probe_mult_factor=params['probe_mult_factor'])
            men_embeds = embed_data['men_embeds']
            men_index = data_process.get_index_from_embeds(
                men_embeds,
                force_exact_search=params['force_exact_search'],
                probe_mult_factor=params['probe_mult_factor'])
        else:
            logger.info("Dictionary: Embedding and building index")
            dict_embeds, dict_index = data_process.embed_and_index(
                reranker,
                test_dict_vecs,
                'candidate',
                n_gpu=n_gpu,
                force_exact_search=params['force_exact_search'],
                batch_size=params['embed_batch_size'],
                probe_mult_factor=params['probe_mult_factor'])
            logger.info("Queries: Embedding and building index")
            vecs = test_men_vecs
            if params['transductive']:
                vecs = torch.cat((train_men_vecs, vecs), dim=0)
            men_embeds, men_index = data_process.embed_and_index(
                reranker,
                vecs,
                'context',
                n_gpu=n_gpu,
                force_exact_search=params['force_exact_search'],
                batch_size=params['embed_batch_size'],
                probe_mult_factor=params['probe_mult_factor'])

    # Save computed embedding data if not loaded from disk
    if embed_data is None:
        embed_data = {}
        embed_data['dict_embeds'] = dict_embeds
        embed_data['men_embeds'] = men_embeds
        if use_types:
            embed_data['dict_idxs_by_type'] = dict_idxs_by_type
            embed_data['men_idxs_by_type'] = men_idxs_by_type
        # NOTE: Cannot pickle faiss index because it is a SwigPyObject
        torch.save(embed_data,
                   embed_data_path,
                   pickle_protocol=pickle.HIGHEST_PROTOCOL)

    logger.info("Starting KNN search...")
    # Fetch (k+1) NN mention candidates
    if not use_types:
        _men_embeds = men_embeds
        if params['transductive']:
            _men_embeds = _men_embeds[n_train_mentions:]
        n_mens_to_fetch = len(_men_embeds) if within_doc else knn + 1
        nn_men_dists, nn_men_idxs = men_index.search(_men_embeds,
                                                     n_mens_to_fetch)
    else:
        query_len = len(men_embeds) - (n_train_mentions
                                       if params['transductive'] else 0)
        nn_men_idxs = -1 * np.ones((query_len, query_len), dtype=int)
        nn_men_dists = -1 * np.ones((query_len, query_len), dtype='float64')
        for entity_type in men_indexes:
            men_embeds_by_type = men_embeds[men_idxs_by_type[entity_type][
                men_idxs_by_type[entity_type] >=
                n_train_mentions]] if params['transductive'] else men_embeds[
                    men_idxs_by_type[entity_type]]
            n_mens_to_fetch = len(
                men_embeds_by_type) if within_doc else knn + 1
            nn_men_dists_by_type, nn_men_idxs_by_type = men_indexes[
                entity_type].search(
                    men_embeds_by_type,
                    min(n_mens_to_fetch, len(men_embeds_by_type)))
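            # Map the type-local indices returned by the per-type index back to
            # global mention indices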
            nn_men_idxs_by_type = np.array(
                list(
                    map(lambda x: men_idxs_by_type[entity_type][x],
                        nn_men_idxs_by_type)))
            i = -1
            for idx in men_idxs_by_type[entity_type]:
                if params['transductive']:
                    idx -= n_train_mentions
                if idx < 0:
                    continue
                i += 1
                nn_men_idxs[idx][:len(nn_men_idxs_by_type[i]
                                      )] = nn_men_idxs_by_type[i]
                nn_men_dists[idx][:len(nn_men_dists_by_type[i]
                                       )] = nn_men_dists_by_type[i]
    logger.info("Search finished")

    logger.info(f'Calculating mention recall@{knn}')
    # Find the most similar entity and k-nn mentions for each mention query
    men_recall_knn = []
    men_cands = []
    cui_sums = defaultdict(int)
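    # Count test mentions per gold CUI (using each mention's first gold label)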
    for m in mention_data:
        cui_sums[m['label_cuis'][0]] += 1
    for idx in range(len(nn_men_idxs)):
        filter_mask_neg1 = nn_men_idxs[idx] != -1
        men_cand_idxs = nn_men_idxs[idx][filter_mask_neg1]
        men_cand_scores = nn_men_dists[idx][filter_mask_neg1]

        if within_doc:
            men_cand_idxs, wd_mask = filter_by_context_doc_id(
                men_cand_idxs,
                test_context_doc_ids[idx],
                test_context_doc_ids,
                return_numpy=True)
            men_cand_scores = men_cand_scores[wd_mask]

        # Filter candidates to remove mention query and keep only the top k candidates
        filter_mask = men_cand_idxs != idx
        men_cand_idxs, men_cand_scores = men_cand_idxs[
            filter_mask][:knn], men_cand_scores[filter_mask][:knn]
        assert len(men_cand_idxs) == knn
        men_cands.append(men_cand_idxs)
        # Calculate mention recall@k
        gold_cui = mention_data[idx]['label_cuis'][0]
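        # Only mentions whose gold CUI appears more than once are evaluated, since a
        # singleton has no co-clustered mention to retrieve; the denominator caps
        # recall by the number of retrievable gold neighbors, min(count - 1, knn)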
        if cui_sums[gold_cui] > 1:
            recall_hit = [
                1. for midx in men_cand_idxs
                if mention_data[midx]['label_cuis'][0] == gold_cui
            ]
            recall_knn = sum(recall_hit) / min(cui_sums[gold_cui] - 1, knn)
            men_recall_knn.append(recall_knn)
    logger.info('Done')
    # assert len(men_recall_knn) == len(nn_men_idxs)

    # Pickle the graphs
    print(f"Saving top-{knn} candidates")
    with open(cands_path, 'wb') as write_handle:
        pickle.dump(men_cands, write_handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Output final results
    mean_recall_knn = np.mean(men_recall_knn)
    logger.info(f"Result: Final mean recall@{knn} = {mean_recall_knn}")
Example #2
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"], 'log')

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path

    embed_data_path = params["embed_data_path"]
    if embed_data_path is None or not os.path.exists(embed_data_path):
        embed_data_path = output_path

    # Init model
    reranker = BiEncoderRanker(params)
    reranker.model.eval()
    tokenizer = reranker.tokenizer
    n_gpu = reranker.n_gpu

    knn = params["knn"]  # Use as the max-knn value for the graph construction
    use_types = params["use_types"]
    # within_doc = params["within_doc"]
    data_split = params["data_split"]  # Default = "test"

    # Load test data
    entity_dictionary_loaded = False
    test_dictionary_pkl_path = os.path.join(pickle_src_path,
                                            'test_dictionary.pickle')
    test_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                             'test_tensor_data.pickle')
    test_mention_data_pkl_path = os.path.join(pickle_src_path,
                                              'test_mention_data.pickle')
    # if params['transductive']:
    #     train_tensor_data_pkl_path = os.path.join(pickle_src_path, 'train_tensor_data.pickle')
    #     train_mention_data_pkl_path = os.path.join(pickle_src_path, 'train_mention_data.pickle')
    if os.path.isfile(test_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(test_dictionary_pkl_path, 'rb') as read_handle:
            test_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if os.path.isfile(test_tensor_data_pkl_path) and os.path.isfile(
            test_mention_data_pkl_path):
        print("Loading stored processed test data...")
        with open(test_tensor_data_pkl_path, 'rb') as read_handle:
            test_tensor_data = pickle.load(read_handle)
        with open(test_mention_data_pkl_path, 'rb') as read_handle:
            mention_data = pickle.load(read_handle)
    else:
        test_samples = utils.read_dataset(data_split, params["data_path"])
        if not entity_dictionary_loaded:
            with open(os.path.join(params["data_path"], 'dictionary.pickle'),
                      'rb') as read_handle:
                test_dictionary = pickle.load(read_handle)

        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in test_samples[0].keys()
        # Filter samples without gold entities
        test_samples = list(
            filter(
                lambda sample: (len(sample["labels"]) > 0) if mult_labels else
                (sample["label"] is not None), test_samples))
        logger.info("Read %d test samples." % len(test_samples))

        mention_data, test_dictionary, test_tensor_data = data_process.process_mention_data(
            test_samples,
            test_dictionary,
            tokenizer,
            params["max_context_length"],
            params["max_cand_length"],
            multi_label_key="labels" if mult_labels else None,
            context_key=params["context_key"],
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            knn=knn,
            dictionary_processed=entity_dictionary_loaded)
        print("Saving processed test data...")
        if not entity_dictionary_loaded:
            with open(test_dictionary_pkl_path, 'wb') as write_handle:
                pickle.dump(test_dictionary,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(test_tensor_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_mention_data_pkl_path, 'wb') as write_handle:
            pickle.dump(mention_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    # Reduce the entity dictionary to only the ground-truth entities of the mention queries,
    # then combine the entities and mentions into one structure for joint embedding and indexing
    new_ents = {}
    new_ents_arr = []
    men_labels = []
    for men in mention_data:
        ent = men['label_idxs'][0]
        if ent not in new_ents:
            new_ents[ent] = len(new_ents_arr)
            new_ents_arr.append(ent)
        men_labels.append(new_ents[ent])
    ent_labels = list(range(len(new_ents_arr)))
    new_ent_vecs = torch.tensor(
        list(map(lambda x: test_dictionary[x]['ids'], new_ents_arr)))
    new_ent_types = list(
        map(lambda x: {"type": test_dictionary[x]['type']}, new_ents_arr))
    test_men_vecs = test_tensor_data[:][0]

    n_mentions = len(test_tensor_data)
    n_entities = len(new_ent_vecs)
    n_embeds = n_mentions + n_entities
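    # Leaf labels follow the concatenation order of all_vecs: entity rows first
    # (each its own label), then mention rows (labeled by their gold entity)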
    leaf_labels = np.array(ent_labels + men_labels, dtype=int)
    all_vecs = torch.cat((new_ent_vecs, test_men_vecs))
    all_types = new_ent_types + mention_data  # Array of dicts containing key "type" for selected ents and all mentions

    # Values of k to run the evaluation against
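    # Geometric series 25, 50, ..., knn (assumes knn is 25 times a power of 2; e.g.,
    # knn=100 gives [25, 50, 100]), unless a single exact_knn value is specified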
    knn_vals = [25 * 2**i for i in range(int(math.log(knn / 25, 2)) + 1)
                ] if params["exact_knn"] is None else [params["exact_knn"]]
    # Store the maximum evaluation k
    max_knn = knn_vals[-1]

    time_start = time.time()

    # Check if graphs are already built
    graph_path = os.path.join(output_path, 'graphs.pickle')
    knn_fetch_time = 0.  # Stays 0 if graphs are loaded from disk (no KNN fetch is timed)
    if os.path.isfile(graph_path):
        print("Loading stored joint graphs...")
        with open(graph_path, 'rb') as read_handle:
            joint_graphs = pickle.load(read_handle)
    else:
        # Initialize graphs to store mention-mention and mention-entity similarity score edges;
        # Keyed on k, the number of nearest mentions retrieved
        joint_graphs = {}
        for k in knn_vals:
            joint_graphs[k] = {
                'rows': np.array([]),
                'cols': np.array([]),
                'data': np.array([]),
                'shape': (n_embeds, n_embeds)
            }

        # Check and load stored embedding data
        embed_data_path = os.path.join(embed_data_path, 'embed_data.t7')
        embed_data = None
        if os.path.isfile(embed_data_path):
            embed_data = torch.load(embed_data_path)
        if use_types:
            if embed_data is not None:
                logger.info('Loading stored embeddings')
                embeds = embed_data['embeds']
                if 'idxs_by_type' in embed_data:
                    idxs_by_type = embed_data['idxs_by_type']
                else:
                    idxs_by_type = data_process.get_idxs_by_type(all_types)
            else:
                logger.info("Embedding data")
                dict_embeds = data_process.embed_and_index(
                    reranker,
                    all_vecs[:n_entities],
                    encoder_type='candidate',
                    only_embed=True,
                    n_gpu=n_gpu,
                    batch_size=params['embed_batch_size'])
                men_embeds = data_process.embed_and_index(
                    reranker,
                    all_vecs[n_entities:],
                    encoder_type='context',
                    only_embed=True,
                    n_gpu=n_gpu,
                    batch_size=params['embed_batch_size'])
                embeds = np.concatenate((dict_embeds, men_embeds), axis=0)
                idxs_by_type = data_process.get_idxs_by_type(all_types)
        else:
            if embed_data is not None:
                logger.info('Loading stored embeddings')
                embeds = embed_data['embeds']
            else:
                logger.info("Embedding data")
                dict_embeds = data_process.embed_and_index(
                    reranker,
                    all_vecs[:n_entities],
                    encoder_type='candidate',
                    only_embed=True,
                    n_gpu=n_gpu,
                    batch_size=params['embed_batch_size'])
                men_embeds = data_process.embed_and_index(
                    reranker,
                    all_vecs[n_entities:],
                    encoder_type='context',
                    only_embed=True,
                    n_gpu=n_gpu,
                    batch_size=params['embed_batch_size'])
                embeds = np.concatenate((dict_embeds, men_embeds), axis=0)
        # Save computed embedding data if not loaded from disk
        if embed_data is None:
            embed_data = {}
            embed_data['embeds'] = embeds
            if use_types:
                embed_data['idxs_by_type'] = idxs_by_type
            # NOTE: Cannot pickle faiss index because it is a SwigPyObject
            torch.save(embed_data,
                       embed_data_path,
                       pickle_protocol=pickle.HIGHEST_PROTOCOL)

        # Build faiss search index
        if params["normalize_embeds"]:
            embeds = normalize(embeds, axis=0)
        logger.info("Building KNN index...")
        if use_types:
            search_indexes = data_process.get_index_from_embeds(
                embeds, corpus_idxs=idxs_by_type, force_exact_search=True)
        else:
            search_index = data_process.get_index_from_embeds(
                embeds, force_exact_search=True)

        logger.info("Starting KNN search...")
        if not use_types:
            faiss_dists, faiss_idxs = search_index.search(embeds, max_knn + 1)
        else:
            query_len = n_embeds
            faiss_idxs = np.zeros((query_len, max_knn + 1), dtype=int)
            faiss_dists = np.zeros((query_len, max_knn + 1), dtype=float)
            for entity_type in search_indexes:
                embeds_by_type = embeds[idxs_by_type[entity_type]]
                nn_dists_by_type, nn_idxs_by_type = search_indexes[
                    entity_type].search(embeds_by_type, max_knn + 1)
                for i, idx in enumerate(idxs_by_type[entity_type]):
                    faiss_idxs[idx] = nn_idxs_by_type[i]
                    faiss_dists[idx] = nn_dists_by_type[i]
        logger.info("Search finished")

        logger.info('Building graphs')
        # Find the most similar nodes for each node in the set (excluding self)
        for idx in trange(n_embeds):
            # Compute adjacent node edge weight
            if idx != 0:
                adj_idx = idx - 1
                adj_data = embeds[adj_idx] @ embeds[idx]
            nn_idxs = faiss_idxs[idx]
            nn_scores = faiss_dists[idx]
            # Filter candidates to remove mention query and keep only the top k candidates
            filter_mask = nn_idxs != idx
            nn_idxs, nn_scores = nn_idxs[filter_mask][:max_knn], nn_scores[
                filter_mask][:max_knn]
            # Add edges to the graphs
            for k in joint_graphs:
                joint_graph = joint_graphs[k]
                # Add edge to adjacent node to force the graph to be connected
                if idx != 0:
                    joint_graph['rows'] = np.append(joint_graph['rows'],
                                                    adj_idx)
                    joint_graph['cols'] = np.append(joint_graph['cols'], idx)
                    joint_graph['data'] = np.append(joint_graph['data'],
                                                    adj_data)
                # Add mention-mention edges
                joint_graph['rows'] = np.append(joint_graph['rows'], [idx] * k)
                joint_graph['cols'] = np.append(joint_graph['cols'],
                                                nn_idxs[:k])
                joint_graph['data'] = np.append(joint_graph['data'],
                                                nn_scores[:k])

        knn_fetch_time = time.time() - time_start
        # Pickle the graphs
        print("Saving joint graphs...")
        with open(graph_path, 'wb') as write_handle:
            pickle.dump(joint_graphs,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

        if params['only_embed_and_build']:
            logger.info(f"Saved embedding data at: {embed_data_path}")
            logger.info(f"Saved graphs at: {graph_path}")
            exit()

    results = {
        'n_leaves': n_embeds,
        'n_entities': n_entities,
        'n_mentions': n_mentions
    }

    graph_processing_time = time.time()
    n_graphs_processed = 0.
    linkage_fns = ["single", "complete", "average"] if params["linkage"] is None \
        else [params["linkage"]]  # Different HAC linkage functions to run the analyses over

    for fn in linkage_fns:
        logger.info(f"Linkage function: {fn}")
        purities = []
        fn_result = {}
        for k in joint_graphs:
            graph = hg.UndirectedGraph(n_embeds)
            graph.add_edges(joint_graphs[k]['rows'], joint_graphs[k]['cols'])
            weights = -joint_graphs[k][
                'data']  # Negated because Higra expects edge weights as distances, not similarities
            tree = get_hac_tree(graph, weights, linkage=fn)
            purity = hg.dendrogram_purity(tree, leaf_labels)
            fn_result[f"purity@{k}nn"] = purity
            logger.info(f"purity@{k}nn = {purity}")
            purities.append(purity)
            n_graphs_processed += 1
        fn_result["average"] = round(np.mean(purities), 4)
        logger.info(f"average = {fn_result['average']}")
        results[fn] = fn_result

    avg_graph_processing_time = (time.time() -
                                 graph_processing_time) / n_graphs_processed
    avg_per_graph_time = (knn_fetch_time + avg_graph_processing_time) / 60
    execution_time = (time.time() - time_start) / 60

    # Store results
    output_file_name = os.path.join(
        output_path, f"results_{int(time.time())}")

    logger.info(f"Results: \n {results}")
    with open(f'{output_file_name}.json', 'w') as f:
        json.dump(results, f, indent=2)
        print(f"\nResults saved at: {output_file_name}.json")

    logger.info("\nThe avg. per graph evaluation time is {} minutes\n".format(
        avg_per_graph_time))
    logger.info(
        "\nThe total evaluation took {} minutes\n".format(execution_time))
Example #3
def save_topk_biencoder_cands(bi_reranker,
                              use_types,
                              logger,
                              n_gpu,
                              params,
                              bi_tokenizer,
                              max_context_length,
                              max_cand_length,
                              pickle_src_path,
                              topk=64):
    entity_dictionary = load_data('train',
                                  bi_tokenizer,
                                  max_context_length,
                                  max_cand_length,
                                  1,
                                  pickle_src_path,
                                  params,
                                  logger,
                                  return_dict_only=True)
    entity_dict_vecs = torch.tensor(list(
        map(lambda x: x['ids'], entity_dictionary)),
                                    dtype=torch.long)

    logger.info('Biencoder: Embedding and indexing entity dictionary')
    if use_types:
        _, dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
            bi_reranker,
            entity_dict_vecs,
            encoder_type="candidate",
            n_gpu=n_gpu,
            corpus=entity_dictionary,
            force_exact_search=True,
            batch_size=params['embed_batch_size'])
    else:
        _, dict_index = data_process.embed_and_index(
            bi_reranker,
            entity_dict_vecs,
            encoder_type="candidate",
            n_gpu=n_gpu,
            force_exact_search=True,
            batch_size=params['embed_batch_size'])
    logger.info('Biencoder: Embedding and indexing finished')

    for mode in ["train", "valid", "test"]:
        logger.info(
            f"Biencoder: Fetching top-{topk} biencoder candidates for {mode} set"
        )
        _, tensor_data, processed_data = load_data(mode, bi_tokenizer,
                                                   max_context_length,
                                                   max_cand_length, 1,
                                                   pickle_src_path, params,
                                                   logger)
        men_vecs = tensor_data[:][0]

        logger.info('Biencoder: Embedding mention data')
        if use_types:
            men_embeddings, _, men_idxs_by_type = data_process.embed_and_index(
                bi_reranker,
                men_vecs,
                encoder_type="context",
                n_gpu=n_gpu,
                corpus=processed_data,
                force_exact_search=True,
                batch_size=params['embed_batch_size'])
        else:
            men_embeddings = data_process.embed_and_index(
                bi_reranker,
                men_vecs,
                encoder_type="context",
                n_gpu=n_gpu,
                force_exact_search=True,
                batch_size=params['embed_batch_size'],
                only_embed=True)
        logger.info('Biencoder: Embedding finished')

        logger.info("Biencoder: Finding nearest entities for each mention...")
        if not use_types:
            _, bi_dict_nns = dict_index.search(men_embeddings, topk)
        else:
            bi_dict_nns = np.zeros((len(men_embeddings), topk), dtype=int)
            for entity_type in men_idxs_by_type:
                men_embeds_by_type = men_embeddings[
                    men_idxs_by_type[entity_type]]
                _, dict_nns_by_type = dict_indexes[entity_type].search(
                    men_embeds_by_type, topk)
                dict_nns_idxs = np.array(
                    list(
                        map(lambda x: dict_idxs_by_type[entity_type][x],
                            dict_nns_by_type)))
                for i, idx in enumerate(men_idxs_by_type[entity_type]):
                    bi_dict_nns[idx] = dict_nns_idxs[i]
        logger.info("Biencoder: Search finished")

        labels = [-1] * len(bi_dict_nns)
        for men_idx in range(len(bi_dict_nns)):
            gold_idx = processed_data[men_idx]["label_idxs"][0]
            for i in range(len(bi_dict_nns[men_idx])):
                if bi_dict_nns[men_idx][i] == gold_idx:
                    labels[men_idx] = i
                    break

        logger.info(
            f"Biencoder: Saving top-{topk} biencoder candidates for {mode} set"
        )
        save_data_path = os.path.join(params['output_path'],
                                      f'candidates_{mode}_top{topk}.t7')
        torch.save({
            "mode": mode,
            "candidates": bi_dict_nns,
            "labels": labels
        }, save_data_path)
        logger.info("Biencoder: Saved")
Example #4
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"])

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path

    # Init model
    reranker = BiEncoderRanker(params)
    reranker.model.eval()
    tokenizer = reranker.tokenizer
    model = reranker.model
    device = reranker.device
    n_gpu = reranker.n_gpu

    knn = params["knn"]
    directed_graph = params["directed_graph"]
    use_types = params["use_types"]
    data_split = params["data_split"] # Parameter default is "test"

    # Load test data
    entity_dictionary_loaded = False
    test_dictionary_pkl_path = os.path.join(pickle_src_path, 'test_dictionary.pickle')
    test_tensor_data_pkl_path = os.path.join(pickle_src_path, 'test_tensor_data.pickle')
    test_mention_data_pkl_path = os.path.join(pickle_src_path, 'test_mention_data.pickle')
    if os.path.isfile(test_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(test_dictionary_pkl_path, 'rb') as read_handle:
            test_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if os.path.isfile(test_tensor_data_pkl_path) and os.path.isfile(test_mention_data_pkl_path):
        print("Loading stored processed test data...")
        with open(test_tensor_data_pkl_path, 'rb') as read_handle:
            test_tensor_data = pickle.load(read_handle)
        with open(test_mention_data_pkl_path, 'rb') as read_handle:
            mention_data = pickle.load(read_handle)
    else:
        test_samples = utils.read_dataset(data_split, params["data_path"])
        if not entity_dictionary_loaded:
            with open(os.path.join(params["data_path"], 'dictionary.pickle'), 'rb') as read_handle:
                test_dictionary = pickle.load(read_handle)

        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in test_samples[0].keys()
        if params["filter_unlabeled"]:
            # Filter samples without gold entities
            test_samples = list(filter(lambda sample: (len(sample["labels"]) > 0) if mult_labels else (sample["label"] is not None), test_samples))
        logger.info("Read %d test samples." % len(test_samples))

        mention_data, test_dictionary, test_tensor_data = data_process.process_mention_data(
            test_samples,
            test_dictionary,
            tokenizer,
            params["max_context_length"],
            params["max_cand_length"],
            multi_label_key="labels" if mult_labels else None,
            context_key=params["context_key"],
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            knn=knn,
            dictionary_processed=entity_dictionary_loaded
        )
        print("Saving processed test data...")
        if not entity_dictionary_loaded:
            with open(test_dictionary_pkl_path, 'wb') as write_handle:
                pickle.dump(test_dictionary, write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(test_tensor_data, write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_mention_data_pkl_path, 'wb') as write_handle:
            pickle.dump(mention_data, write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    # Store test dictionary token ids
    test_dict_vecs = torch.tensor(
        list(map(lambda x: x['ids'], test_dictionary)), dtype=torch.long)
    # Store test mention token ids
    test_men_vecs = test_tensor_data[:][0]

    n_entities = len(test_dict_vecs)
    n_mentions = len(test_tensor_data)

    # Values of k to run the evaluation against
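    # [0, 1, 2, 4, ..., knn]; k=0 keeps only mention-entity edges (no mention-mention edges)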
    knn_vals = [0] + [2**i for i in range(int(math.log(knn, 2)) + 1)]
    # Store the maximum evaluation k
    max_knn = knn_vals[-1]

    # Check if graphs are already built
    graph_path = os.path.join(output_path, 'graphs.pickle')
    if not params['only_recall'] and os.path.isfile(graph_path):
        print("Loading stored joint graphs...")
        with open(graph_path, 'rb') as read_handle:
            joint_graphs = pickle.load(read_handle)
    else:
        # Initialize graphs to store mention-mention and mention-entity similarity score edges;
        # Keyed on k, the number of nearest mentions retrieved
        joint_graphs = {}
        for k in knn_vals:
            joint_graphs[k] = {
                'rows': np.array([]),
                'cols': np.array([]),
                'data': np.array([]),
                'shape': (n_entities+n_mentions, n_entities+n_mentions)
            }

        if use_types:
            print("Dictionary: Embedding and building index")
            dict_embeds, dict_indexes, dict_idxs_by_type = data_process.embed_and_index(reranker, test_dict_vecs, encoder_type="candidate", n_gpu=n_gpu, corpus=test_dictionary, force_exact_search=True)

            print("Queries: Embedding and building index")
            men_embeds, men_indexes, men_idxs_by_type = data_process.embed_and_index(reranker, test_men_vecs, encoder_type="context", n_gpu=n_gpu, corpus=mention_data, force_exact_search=True)
        else:
            print("Dictionary: Embedding and building index")
            dict_embeds, dict_index = data_process.embed_and_index(
                reranker, test_dict_vecs, 'candidate', n_gpu=n_gpu)
            print("Queries: Embedding and building index")
            men_embeds, men_index = data_process.embed_and_index(
                reranker, test_men_vecs, 'context', n_gpu=n_gpu)

        recall_accuracy = {2**i: 0 for i in range(int(math.log(params['recall_k'], 2)) + 1)}
        recall_idxs = [0.]*params['recall_k']

        # Find the most similar entity and k-nn mentions for each mention query
        for men_query_idx, men_embed in enumerate(tqdm(men_embeds, total=len(men_embeds), desc="Fetching k-NN")):
            men_embed = np.expand_dims(men_embed, axis=0)
            
            dict_type_idx_mapping, men_type_idx_mapping = None, None
            if use_types:
                entity_type = mention_data[men_query_idx]['type']
                dict_index = dict_indexes[entity_type]
                men_index = men_indexes[entity_type]
                dict_type_idx_mapping = dict_idxs_by_type[entity_type]
                men_type_idx_mapping = men_idxs_by_type[entity_type]
            
            # Fetch nearest entity candidate
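            # Gold entity indices for this query (only the first n_labels entries are valid)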
            gold_idxs = mention_data[men_query_idx]["label_idxs"][:mention_data[men_query_idx]["n_labels"]]
            dict_cand_idx, dict_cand_score, recall_idx = get_query_nn(
                1, dict_embeds, dict_index, men_embed, searchK=params['recall_k'], gold_idxs=gold_idxs, type_idx_mapping=dict_type_idx_mapping)
            # Compute recall metric
            if recall_idx > -1:
                recall_idxs[recall_idx] += 1.
                for recall_k in recall_accuracy:
                    if recall_idx < recall_k:
                        recall_accuracy[recall_k] += 1.

            if not params['only_recall']:
                # Fetch (k+1) NN mention candidates
                men_cand_idxs, men_cand_scores = get_query_nn(
                    max_knn + 1, men_embeds, men_index, men_embed, type_idx_mapping=men_type_idx_mapping)
                # Filter candidates to remove mention query and keep only the top k candidates
                filter_mask = men_cand_idxs != men_query_idx
                men_cand_idxs, men_cand_scores = men_cand_idxs[filter_mask][:max_knn], men_cand_scores[filter_mask][:max_knn]

                # Add edges to the graphs
                for k in joint_graphs:
                    joint_graph = joint_graphs[k]
                    # Add mention-entity edge
                    joint_graph['rows'] = np.append(
                        joint_graph['rows'], [n_entities+men_query_idx])  # Mentions added at an offset of maximum entities
                    joint_graph['cols'] = np.append(
                        joint_graph['cols'], dict_cand_idx)
                    joint_graph['data'] = np.append(
                        joint_graph['data'], dict_cand_score)
                    if k > 0:
                        # Add mention-mention edges
                        joint_graph['rows'] = np.append(
                            joint_graph['rows'], [n_entities+men_query_idx]*len(men_cand_idxs[:k]))
                        joint_graph['cols'] = np.append(
                            joint_graph['cols'], n_entities+men_cand_idxs[:k])
                        joint_graph['data'] = np.append(
                            joint_graph['data'], men_cand_scores[:k])

        # Compute and print recall metric
        recall_idx_mode = np.argmax(recall_idxs)
        recall_idx_mode_prop = recall_idxs[recall_idx_mode]/np.sum(recall_idxs)
        logger.info(f"""
        Recall metrics (for {len(men_embeds)} queries):
        ---------------""")
        logger.info(f"highest recall idx = {recall_idx_mode} ({recall_idxs[recall_idx_mode]}/{np.sum(recall_idxs)} = {recall_idx_mode_prop})")
        for recall_k in recall_accuracy:
            recall_accuracy[recall_k] /= len(men_embeds)
            logger.info(f"recall@{recall_k} = {recall_accuracy[recall_k]}")

        if params['only_recall']:
            exit()

        # Pickle the graphs
        print("Saving joint graphs...")
        with open(graph_path, 'wb') as write_handle:
            pickle.dump(joint_graphs, write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    results = []
    for k in joint_graphs:
        print(f"\nGraph (k={k}):")
        # Partition graph based on cluster-linking constraints
        partitioned_graph, clusters = partition_graph(
            joint_graphs[k], n_entities, directed_graph, return_clusters=True)
        # Infer predictions from clusters
        result = analyzeClusters(clusters, test_dictionary, mention_data, k)
        # Store result
        results.append(result)

    # Store results
    output_file_name = os.path.join(
        output_path, f"eval_results_{__import__('calendar').timegm(__import__('time').gmtime())}")
    result_overview = {
        'n_entities': results[0]['n_entities'],
        'n_mentions': results[0]['n_mentions'],
        'directed': directed_graph
    }

    try:
        for recall_k in recall_accuracy:
            result_overview[f'recall@{recall_k}'] = recall_accuracy[recall_k]
    except NameError:
        logger.info("Recall data not available since graphs were loaded from disk")
    
    for r in results:
        k = r['knn_mentions']
        result_overview[f'accuracy@knn{k}'] = r['accuracy']
        logger.info(f"accuracy@knn{k} = {r['accuracy']}")
        output_file = f'{output_file_name}-{k}.json'
        with open(output_file, 'w') as f:
            json.dump(r, f, indent=2)
            print(f"\nPredictions @knn{k} saved at: {output_file}")
    with open(f'{output_file_name}.json', 'w') as f:
        json.dump(result_overview, f, indent=2)
        print(f"\nPredictions overview saved at: {output_file_name}.json")
Example #5
def main(params):
    model_output_path = params["output_path"]
    if not os.path.exists(model_output_path):
        os.makedirs(model_output_path)
    logger = utils.get_logger(params["output_path"])

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = model_output_path

    knn = params["knn"]
    use_types = params["use_types"]

    # Init model
    reranker = BiEncoderRanker(params)
    tokenizer = reranker.tokenizer
    model = reranker.model

    device = reranker.device
    n_gpu = reranker.n_gpu

    if params["gradient_accumulation_steps"] < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(params["gradient_accumulation_steps"]))

    # An effective batch size of `x` when accumulating gradients across `y` batches is achieved with a per-step batch size of `z = x / y` (e.g., an effective batch size of 128 over 4 accumulation steps uses per-step batches of 32)
    params["train_batch_size"] = (params["train_batch_size"] //
                                  params["gradient_accumulation_steps"])
    train_batch_size = params["train_batch_size"]
    eval_batch_size = params["eval_batch_size"]
    grad_acc_steps = params["gradient_accumulation_steps"]

    # Fix the random seeds
    seed = params["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if reranker.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    entity_dictionary_loaded = False
    entity_dictionary_pkl_path = os.path.join(pickle_src_path,
                                              'entity_dictionary.pickle')
    if os.path.isfile(entity_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(entity_dictionary_pkl_path, 'rb') as read_handle:
            entity_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if not params["only_evaluate"]:
        # Load train data
        train_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                                  'train_tensor_data.pickle')
        train_processed_data_pkl_path = os.path.join(
            pickle_src_path, 'train_processed_data.pickle')
        if os.path.isfile(train_tensor_data_pkl_path) and os.path.isfile(
                train_processed_data_pkl_path):
            print("Loading stored processed train data...")
            with open(train_tensor_data_pkl_path, 'rb') as read_handle:
                train_tensor_data = pickle.load(read_handle)
            with open(train_processed_data_pkl_path, 'rb') as read_handle:
                train_processed_data = pickle.load(read_handle)
        else:
            train_samples = utils.read_dataset("train", params["data_path"])
            if not entity_dictionary_loaded:
                with open(
                        os.path.join(params["data_path"], 'dictionary.pickle'),
                        'rb') as read_handle:
                    entity_dictionary = pickle.load(read_handle)

            # Check if dataset has multiple ground-truth labels
            mult_labels = "labels" in train_samples[0].keys()
            if params["filter_unlabeled"]:
                # Filter samples without gold entities
                train_samples = list(
                    filter(
                        lambda sample: (len(sample["labels"]) > 0)
                        if mult_labels else (sample["label"] is not None),
                        train_samples))
            logger.info("Read %d train samples." % len(train_samples))

            # For discovery experiment: Drop entities used in training that were dropped randomly from dev/test set
            if params["drop_entities"]:
                assert entity_dictionary_loaded
                drop_set_pkl_path = os.path.join(
                    pickle_src_path, 'drop_set_mention_data.pickle')
                with open(drop_set_pkl_path, 'rb') as read_handle:
                    drop_set_data = pickle.load(read_handle)
                drop_set_mention_gold_cui_idxs = list(
                    map(lambda x: x['label_idxs'][0], drop_set_data))
                ents_in_data = np.unique(drop_set_mention_gold_cui_idxs)
                ent_drop_prop = 0.1
                logger.info(
                    f"Dropping {ent_drop_prop*100}% of {len(ents_in_data)} entities found in drop set"
                )
                # Get entity indices to drop
                n_ents_dropped = int(ent_drop_prop * len(ents_in_data))
                rng = np.random.default_rng(seed=17)
                dropped_ent_idxs = rng.choice(ents_in_data,
                                              size=n_ents_dropped,
                                              replace=False)

                # Drop entities from dictionary (subsequent processing will automatically drop corresponding mentions)
                keep_mask = np.ones(len(entity_dictionary), dtype='bool')
                keep_mask[dropped_ent_idxs] = False
                entity_dictionary = np.array(entity_dictionary)[keep_mask]

            train_processed_data, entity_dictionary, train_tensor_data = data_process.process_mention_data(
                train_samples,
                entity_dictionary,
                tokenizer,
                params["max_context_length"],
                params["max_cand_length"],
                context_key=params["context_key"],
                multi_label_key="labels" if mult_labels else None,
                silent=params["silent"],
                logger=logger,
                debug=params["debug"],
                knn=knn,
                dictionary_processed=entity_dictionary_loaded)
            print("Saving processed train data...")
            if not entity_dictionary_loaded:
                with open(entity_dictionary_pkl_path, 'wb') as write_handle:
                    pickle.dump(entity_dictionary,
                                write_handle,
                                protocol=pickle.HIGHEST_PROTOCOL)
            with open(train_tensor_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_tensor_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            with open(train_processed_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_processed_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

        # Store the query mention vectors
        train_men_vecs = train_tensor_data[:][0]

        if params["shuffle"]:
            train_sampler = RandomSampler(train_tensor_data)
        else:
            train_sampler = SequentialSampler(train_tensor_data)

        train_dataloader = DataLoader(train_tensor_data,
                                      sampler=train_sampler,
                                      batch_size=train_batch_size)

    # Store the entity dictionary vectors
    entity_dict_vecs = torch.tensor(list(
        map(lambda x: x['ids'], entity_dictionary)),
                                    dtype=torch.long)

    # Load eval data
    valid_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                              'valid_tensor_data.pickle')
    valid_processed_data_pkl_path = os.path.join(
        pickle_src_path, 'valid_processed_data.pickle')
    if os.path.isfile(valid_tensor_data_pkl_path) and os.path.isfile(
            valid_processed_data_pkl_path):
        print("Loading stored processed valid data...")
        with open(valid_tensor_data_pkl_path, 'rb') as read_handle:
            valid_tensor_data = pickle.load(read_handle)
        with open(valid_processed_data_pkl_path, 'rb') as read_handle:
            valid_processed_data = pickle.load(read_handle)
    else:
        valid_samples = utils.read_dataset("valid", params["data_path"])
        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in valid_samples[0].keys()
        # Filter samples without gold entities
        valid_samples = list(
            filter(
                lambda sample: (len(sample["labels"]) > 0) if mult_labels else
                (sample["label"] is not None), valid_samples))
        logger.info("Read %d valid samples." % len(valid_samples))

        valid_processed_data, _, valid_tensor_data = data_process.process_mention_data(
            valid_samples,
            entity_dictionary,
            tokenizer,
            params["max_context_length"],
            params["max_cand_length"],
            context_key=params["context_key"],
            multi_label_key="labels" if mult_labels else None,
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            knn=knn,
            dictionary_processed=True)
        print("Saving processed valid data...")
        with open(valid_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(valid_tensor_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        with open(valid_processed_data_pkl_path, 'wb') as write_handle:
            pickle.dump(valid_processed_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    # Store the query mention vectors
    valid_men_vecs = valid_tensor_data[:][0]

    # valid_sampler = SequentialSampler(valid_tensor_data)
    # valid_dataloader = DataLoader(
    #     valid_tensor_data, sampler=valid_sampler, batch_size=eval_batch_size
    # )

    if params["only_evaluate"]:
        evaluate(reranker,
                 entity_dict_vecs,
                 valid_men_vecs,
                 device=device,
                 logger=logger,
                 knn=knn,
                 n_gpu=n_gpu,
                 entity_data=entity_dictionary,
                 query_data=valid_processed_data,
                 silent=params["silent"],
                 use_types=use_types or params["use_types_for_eval"],
                 embed_batch_size=params['embed_batch_size'],
                 force_exact_search=use_types or params["use_types_for_eval"]
                 or params["force_exact_search"],
                 probe_mult_factor=params['probe_mult_factor'])
        exit()

    time_start = time.time()
    utils.write_to_file(os.path.join(model_output_path, "training_params.txt"),
                        str(params))
    logger.info("Starting training")
    logger.info("device: {} n_gpu: {}, data_parallel: {}".format(
        device, n_gpu, params["data_parallel"]))

    # Prepare the optimizer and learning-rate scheduler
    optimizer = get_optimizer(model, params)
    scheduler = get_scheduler(params, optimizer, len(train_tensor_data),
                              logger)
    best_epoch_idx = -1
    best_score = -1
    num_train_epochs = params["num_train_epochs"]

    init_base_model_run = params.get("path_to_model") is None
    init_run_pkl_path = os.path.join(
        pickle_src_path, f'init_run_{"type" if use_types else "notype"}.t7')
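    # Cache for the base (not yet fine-tuned) model's embeddings, so repeated
    # runs starting from the same base model can skip the initial embedding pass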

    dict_embed_data = None

    for epoch_idx in trange(int(num_train_epochs), desc="Epoch"):
        model.train()
        torch.cuda.empty_cache()
        tr_loss = 0
        results = None

        # Check if embeddings and index can be loaded
        init_run_data_loaded = False
        if init_base_model_run:
            if os.path.isfile(init_run_pkl_path):
                logger.info('Loading init run data')
                init_run_data = torch.load(init_run_pkl_path)
                init_run_data_loaded = True
        load_stored_data = init_base_model_run and init_run_data_loaded

        # Compute mention and entity embeddings at the start of each epoch
        if use_types:
            if load_stored_data:
                train_dict_embeddings, dict_idxs_by_type = init_run_data[
                    'train_dict_embeddings'], init_run_data[
                        'dict_idxs_by_type']
                train_dict_indexes = data_process.get_index_from_embeds(
                    train_dict_embeddings,
                    dict_idxs_by_type,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
                train_men_embeddings, men_idxs_by_type = init_run_data[
                    'train_men_embeddings'], init_run_data['men_idxs_by_type']
                train_men_indexes = data_process.get_index_from_embeds(
                    train_men_embeddings,
                    men_idxs_by_type,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
            else:
                logger.info('Embedding and indexing')
                if dict_embed_data is not None:
                    train_dict_embeddings, train_dict_indexes, dict_idxs_by_type = dict_embed_data[
                        'dict_embeds'], dict_embed_data[
                            'dict_indexes'], dict_embed_data[
                                'dict_idxs_by_type']
                else:
                    train_dict_embeddings, train_dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
                        reranker,
                        entity_dict_vecs,
                        encoder_type="candidate",
                        n_gpu=n_gpu,
                        corpus=entity_dictionary,
                        force_exact_search=params['force_exact_search'],
                        batch_size=params['embed_batch_size'],
                        probe_mult_factor=params['probe_mult_factor'])
                train_men_embeddings, train_men_indexes, men_idxs_by_type = data_process.embed_and_index(
                    reranker,
                    train_men_vecs,
                    encoder_type="context",
                    n_gpu=n_gpu,
                    corpus=train_processed_data,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])
        else:
            if load_stored_data:
                train_dict_embeddings = init_run_data['train_dict_embeddings']
                train_dict_index = data_process.get_index_from_embeds(
                    train_dict_embeddings,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
                train_men_embeddings = init_run_data['train_men_embeddings']
                train_men_index = data_process.get_index_from_embeds(
                    train_men_embeddings,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
            else:
                logger.info('Embedding and indexing')
                if dict_embed_data is not None:
                    train_dict_embeddings, train_dict_index = dict_embed_data[
                        'dict_embeds'], dict_embed_data['dict_index']
                else:
                    train_dict_embeddings, train_dict_index = data_process.embed_and_index(
                        reranker,
                        entity_dict_vecs,
                        encoder_type="candidate",
                        n_gpu=n_gpu,
                        force_exact_search=params['force_exact_search'],
                        batch_size=params['embed_batch_size'],
                        probe_mult_factor=params['probe_mult_factor'])
                train_men_embeddings, train_men_index = data_process.embed_and_index(
                    reranker,
                    train_men_vecs,
                    encoder_type="context",
                    n_gpu=n_gpu,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])

        # Save the initial embeddings and index if this is the first run and data isn't persistent
        if init_base_model_run and not load_stored_data:
            init_run_data = {}
            init_run_data['train_dict_embeddings'] = train_dict_embeddings
            init_run_data['train_men_embeddings'] = train_men_embeddings
            if use_types:
                init_run_data['dict_idxs_by_type'] = dict_idxs_by_type
                init_run_data['men_idxs_by_type'] = men_idxs_by_type
            # NOTE: Cannot pickle faiss index because it is a SwigPyObject
            torch.save(init_run_data,
                       init_run_pkl_path,
                       pickle_protocol=pickle.HIGHEST_PROTOCOL)
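            # (If the index itself had to be persisted, faiss's
            # serialize_index/deserialize_index could convert it to a NumPy
            # array first; here it is simply rebuilt from the saved embeddings.)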

        init_base_model_run = False

        if params["silent"]:
            iter_ = train_dataloader
        else:
            iter_ = tqdm(train_dataloader, desc="Batch")

        logger.info("Starting KNN search...")
        if not use_types:
            _, dict_nns = train_dict_index.search(train_men_embeddings, knn)
        else:
            dict_nns = np.zeros((len(train_men_embeddings), knn))
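            # Search each type-specific index separately, then map the
            # per-type neighbor positions back to global dictionary indices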
            for entity_type in train_men_indexes:
                men_embeds_by_type = train_men_embeddings[
                    men_idxs_by_type[entity_type]]
                _, dict_nns_by_type = train_dict_indexes[entity_type].search(
                    men_embeds_by_type, knn)
                dict_nns_idxs = np.array(
                    list(
                        map(lambda x: dict_idxs_by_type[entity_type][x],
                            dict_nns_by_type)))
                for i, idx in enumerate(men_idxs_by_type[entity_type]):
                    dict_nns[idx] = dict_nns_idxs[i]
        logger.info("Search finished")

        for step, batch in enumerate(iter_):
            batch = tuple(t.to(device) for t in batch)
            context_inputs, candidate_idxs, n_gold, mention_idxs = batch
            mention_embeddings = train_men_embeddings[mention_idxs.cpu()]

            if len(mention_embeddings.shape) == 1:
                mention_embeddings = np.expand_dims(mention_embeddings, axis=0)

            # context_inputs: Shape: batch x token_len
            candidate_inputs = np.array(
                [], dtype=np.int64)  # Shape: (batch*knn) x token_len
            label_inputs = torch.tensor(
                [[1] + [0] * (knn - 1)] * n_gold.sum(),
                dtype=torch.float32)  # Shape: batch(with split rows) x knn
            context_inputs_split = torch.zeros(
                (label_inputs.size(0), context_inputs.size(1)),
                dtype=torch.long)  # Shape: batch(with split rows) x token_len
            # label_inputs = (candidate_idxs >= 0).type(torch.float32) # Shape: batch x knn
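            # For every (mention, gold-entity) pair, build a candidate list of
            # the gold entity followed by the mention's nearest non-gold
            # entities, truncated to knn; label_inputs marks position 0 as the
            # positive, and context_inputs_split duplicates each mention's
            # context once per gold label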

            row_idx = 0
            for i, m_embed in enumerate(mention_embeddings):
                knn_dict_idxs = dict_nns[mention_idxs[i]]
                knn_dict_idxs = knn_dict_idxs.astype(np.int64).flatten()
                gold_idxs = candidate_idxs[i][:n_gold[i]].cpu()
                for gold_idx in gold_idxs:
                    # A running row counter is needed here: an `i + ng` offset
                    # would collide once any mention has more than one gold label
                    context_inputs_split[row_idx] = context_inputs[i]
                    row_idx += 1
                    candidate_inputs = np.concatenate(
                        (candidate_inputs,
                         np.concatenate(
                             ([gold_idx],
                              knn_dict_idxs[~np.isin(knn_dict_idxs, gold_idxs)]
                              ))[:knn]))
            candidate_inputs = torch.tensor(
                list(
                    map(lambda x: entity_dict_vecs[x].numpy(),
                        candidate_inputs))).cuda()
            context_inputs_split = context_inputs_split.cuda()
            label_inputs = label_inputs.cuda()

            loss, _ = reranker(context_inputs_split,
                               candidate_inputs,
                               label_inputs,
                               pos_neg_loss=params["pos_neg_loss"])

            if grad_acc_steps > 1:
                loss = loss / grad_acc_steps

            tr_loss += loss.item()

            if (step + 1) % (params["print_interval"] * grad_acc_steps) == 0:
                logger.info("Step {} - epoch {} average loss: {}\n".format(
                    step,
                    epoch_idx,
                    tr_loss / (params["print_interval"] * grad_acc_steps),
                ))
                tr_loss = 0

            loss.backward()

            if (step + 1) % grad_acc_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               params["max_grad_norm"])
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if (step + 1) % (params["eval_interval"] * grad_acc_steps) == 0:
                logger.info("Evaluation on the development dataset")
                evaluate(reranker,
                         entity_dict_vecs,
                         valid_men_vecs,
                         device=device,
                         logger=logger,
                         knn=knn,
                         n_gpu=n_gpu,
                         entity_data=entity_dictionary,
                         query_data=valid_processed_data,
                         silent=params["silent"],
                         use_types=use_types or params["use_types_for_eval"],
                         embed_batch_size=params['embed_batch_size'],
                         force_exact_search=use_types
                         or params["use_types_for_eval"]
                         or params["force_exact_search"],
                         probe_mult_factor=params['probe_mult_factor'])
                model.train()
                logger.info("\n")

        logger.info("***** Saving fine-tuned model *****")
        epoch_output_folder_path = os.path.join(model_output_path,
                                                "epoch_{}".format(epoch_idx))
        utils.save_model(model, tokenizer, epoch_output_folder_path)
        logger.info(f"Model saved at {epoch_output_folder_path}")

        normalized_accuracy, dict_embed_data = evaluate(
            reranker,
            entity_dict_vecs,
            valid_men_vecs,
            device=device,
            logger=logger,
            knn=knn,
            n_gpu=n_gpu,
            entity_data=entity_dictionary,
            query_data=valid_processed_data,
            silent=params["silent"],
            use_types=use_types or params["use_types_for_eval"],
            embed_batch_size=params['embed_batch_size'],
            force_exact_search=use_types or params["use_types_for_eval"]
            or params["force_exact_search"],
            probe_mult_factor=params['probe_mult_factor'])

        if normalized_accuracy > best_score:
            best_score = normalized_accuracy
            best_epoch_idx = epoch_idx
        logger.info("\n")

    execution_time = (time.time() - time_start) / 60
    utils.write_to_file(
        os.path.join(model_output_path, "training_time.txt"),
        "The training took {} minutes\n".format(execution_time),
    )
    logger.info("The training took {} minutes\n".format(execution_time))

    # save the best model in the parent_dir
    logger.info("Best performance in epoch: {}".format(best_epoch_idx))
    params["path_to_model"] = os.path.join(model_output_path,
                                           "epoch_{}".format(best_epoch_idx))
    utils.save_model(reranker.model, tokenizer, model_output_path)
    logger.info(f"Best model saved at {model_output_path}")
Example #6
def evaluate(reranker,
             valid_dict_vecs,
             valid_men_vecs,
             device,
             logger,
             knn,
             n_gpu,
             entity_data,
             query_data,
             silent=False,
             use_types=False,
             embed_batch_size=768,
             force_exact_search=False,
             probe_mult_factor=1):
    torch.cuda.empty_cache()

    reranker.model.eval()
    n_entities = len(valid_dict_vecs)
    n_mentions = len(valid_men_vecs)
    joint_graphs = {}
    max_knn = 4
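    # Each graph is stored as COO-style components (rows, cols, data, shape),
    # from which the partitioning step can assemble a sparse adjacency matrix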
    for k in [0, 1, 2, 4]:
        joint_graphs[k] = {
            'rows': np.array([]),
            'cols': np.array([]),
            'data': np.array([]),
            'shape': (n_entities + n_mentions, n_entities + n_mentions)
        }

    if use_types:
        logger.info("Eval: Dictionary: Embedding and building index")
        dict_embeds, dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
            reranker,
            valid_dict_vecs,
            encoder_type="candidate",
            n_gpu=n_gpu,
            corpus=entity_data,
            force_exact_search=force_exact_search,
            batch_size=embed_batch_size,
            probe_mult_factor=probe_mult_factor)
        logger.info("Eval: Queries: Embedding and building index")
        men_embeds, men_indexes, men_idxs_by_type = data_process.embed_and_index(
            reranker,
            valid_men_vecs,
            encoder_type="context",
            n_gpu=n_gpu,
            corpus=query_data,
            force_exact_search=force_exact_search,
            batch_size=embed_batch_size,
            probe_mult_factor=probe_mult_factor)
    else:
        logger.info("Eval: Dictionary: Embedding and building index")
        dict_embeds, dict_index = data_process.embed_and_index(
            reranker,
            valid_dict_vecs,
            'candidate',
            n_gpu=n_gpu,
            force_exact_search=force_exact_search,
            batch_size=embed_batch_size,
            probe_mult_factor=probe_mult_factor)
        logger.info("Eval: Queries: Embedding and building index")
        men_embeds, men_index = data_process.embed_and_index(
            reranker,
            valid_men_vecs,
            'context',
            n_gpu=n_gpu,
            force_exact_search=force_exact_search,
            batch_size=embed_batch_size,
            probe_mult_factor=probe_mult_factor)

    logger.info("Eval: Starting KNN search...")
    # Fetch the single nearest entity for every mention
    # Fetch (max_knn + 1) NN mention candidates (one extra, since the query
    # itself is among its own nearest neighbors)
    if not use_types:
        nn_ent_dists, nn_ent_idxs = dict_index.search(men_embeds, 1)
        nn_men_dists, nn_men_idxs = men_index.search(men_embeds, max_knn + 1)
    else:
        nn_ent_idxs = np.zeros((len(men_embeds), 1))
        nn_ent_dists = np.zeros((len(men_embeds), 1), dtype='float64')
        nn_men_idxs = np.zeros((len(men_embeds), max_knn + 1))
        nn_men_dists = np.zeros((len(men_embeds), max_knn + 1),
                                dtype='float64')
        for entity_type in men_indexes:
            men_embeds_by_type = men_embeds[men_idxs_by_type[entity_type]]
            nn_ent_dists_by_type, nn_ent_idxs_by_type = dict_indexes[
                entity_type].search(men_embeds_by_type, 1)
            nn_men_dists_by_type, nn_men_idxs_by_type = men_indexes[
                entity_type].search(men_embeds_by_type, max_knn + 1)
            nn_ent_idxs_by_type = np.array(
                list(
                    map(lambda x: dict_idxs_by_type[entity_type][x],
                        nn_ent_idxs_by_type)))
            nn_men_idxs_by_type = np.array(
                list(
                    map(lambda x: men_idxs_by_type[entity_type][x],
                        nn_men_idxs_by_type)))
            for i, idx in enumerate(men_idxs_by_type[entity_type]):
                nn_ent_idxs[idx] = nn_ent_idxs_by_type[i]
                nn_ent_dists[idx] = nn_ent_dists_by_type[i]
                nn_men_idxs[idx] = nn_men_idxs_by_type[i]
                nn_men_dists[idx] = nn_men_dists_by_type[i]
    logger.info("Eval: Search finished")

    logger.info('Eval: Building graphs')
    for men_query_idx, men_embed in enumerate(
            tqdm(men_embeds,
                 total=len(men_embeds),
                 desc="Eval: Building graphs")):
        # Get nearest entity candidate
        dict_cand_idx = nn_ent_idxs[men_query_idx][0]
        dict_cand_score = nn_ent_dists[men_query_idx][0]

        # Filter candidates to remove mention query and keep only the top k candidates
        men_cand_idxs = nn_men_idxs[men_query_idx]
        men_cand_scores = nn_men_dists[men_query_idx]

        filter_mask = men_cand_idxs != men_query_idx
        men_cand_idxs, men_cand_scores = men_cand_idxs[
            filter_mask][:max_knn], men_cand_scores[filter_mask][:max_knn]

        # Add edges to the graphs
        for k in joint_graphs:
            joint_graph = joint_graphs[k]
            # Add mention-entity edge
            joint_graph['rows'] = np.append(
                joint_graph['rows'],
                [n_entities + men_query_idx
                 ])  # Mentions added at an offset of maximum entities
            joint_graph['cols'] = np.append(joint_graph['cols'], dict_cand_idx)
            joint_graph['data'] = np.append(joint_graph['data'],
                                            dict_cand_score)
            if k > 0:
                # Add mention-mention edges
                joint_graph['rows'] = np.append(joint_graph['rows'],
                                                [n_entities + men_query_idx] *
                                                len(men_cand_idxs[:k]))
                joint_graph['cols'] = np.append(joint_graph['cols'],
                                                n_entities + men_cand_idxs[:k])
                joint_graph['data'] = np.append(joint_graph['data'],
                                                men_cand_scores[:k])

    max_eval_acc = -1.
    for k in joint_graphs:
        logger.info(f"\nEval: Graph (k={k}):")
        # Partition graph based on cluster-linking constraints
        partitioned_graph, clusters = eval_cluster_linking.partition_graph(
            joint_graphs[k], n_entities, directed=True, return_clusters=True)
        # Infer predictions from clusters
        result = eval_cluster_linking.analyzeClusters(clusters, entity_data,
                                                      query_data, k)
        acc = float(result['accuracy'].split(' ')[0])
        max_eval_acc = max(acc, max_eval_acc)
        logger.info(f"Eval: accuracy for graph@k={k}: {acc}%")
    logger.info(f"Eval: Best accuracy: {max_eval_acc}%")
    if use_types:
        return max_eval_acc, {
            'dict_embeds': dict_embeds,
            'dict_indexes': dict_indexes,
            'dict_idxs_by_type': dict_idxs_by_type
        }
    return max_eval_acc, {'dict_embeds': dict_embeds, 'dict_index': dict_index}
Example #7
def evaluate_ind_pred(reranker,
                      valid_dataloader,
                      valid_dict_vecs,
                      params,
                      device,
                      logger,
                      knn,
                      n_gpu,
                      entity_data,
                      query_data,
                      use_types=False,
                      embed_batch_size=768):
    reranker.model.eval()
    knn = max(
        16, 2 * knn
    )  # Accommodate the approximate nature of the knn procedure by retrieving more samples and then filtering
    iter_ = valid_dataloader if params["silent"] else tqdm(valid_dataloader,
                                                           desc="Evaluation")
    eval_accuracy = 0.0
    nb_eval_examples = 0
    nb_eval_steps = 0

    if not use_types:
        valid_dict_embeddings, valid_dict_index = data_process.embed_and_index(
            reranker,
            valid_dict_vecs,
            encoder_type="candidate",
            n_gpu=n_gpu,
            batch_size=embed_batch_size)
    else:
        valid_dict_embeddings, valid_dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
            reranker,
            valid_dict_vecs,
            encoder_type="candidate",
            n_gpu=n_gpu,
            corpus=entity_data,
            batch_size=embed_batch_size)

    for step, batch in enumerate(iter_):
        batch = tuple(t.to(device) for t in batch)
        context_inputs, candidate_idxs, n_gold, mention_idxs = batch

        with torch.no_grad():
            mention_embeddings = reranker.encode_context(context_inputs)
            # context_inputs: Shape: batch x token_len
            candidate_inputs = np.array(
                [], dtype=np.int64)  # Shape: (batch*knn) x token_len
            label_inputs = torch.zeros(
                (context_inputs.shape[0], knn),
                dtype=torch.float32)  # Shape: batch x knn

            for i, m_embed in enumerate(mention_embeddings):
                if use_types:
                    entity_type = query_data[mention_idxs[i]]['type']
                    valid_dict_index = valid_dict_indexes[entity_type]
                _, knn_dict_idxs = valid_dict_index.search(
                    np.expand_dims(m_embed, axis=0), knn)
                knn_dict_idxs = knn_dict_idxs.astype(np.int64).flatten()
                if use_types:
                    # Map type-specific indices to the entire dictionary
                    knn_dict_idxs = list(
                        map(lambda x: dict_idxs_by_type[entity_type][x],
                            knn_dict_idxs))
                gold_idxs = candidate_idxs[i][:n_gold[i]].cpu()
                candidate_inputs = np.concatenate(
                    (candidate_inputs, knn_dict_idxs))
                label_inputs[i] = torch.tensor(
                    [1 if nn in gold_idxs else 0 for nn in knn_dict_idxs])
            candidate_inputs = torch.tensor(
                list(
                    map(lambda x: valid_dict_vecs[x].numpy(),
                        candidate_inputs))).cuda()
            context_inputs = context_inputs.cuda()
            label_inputs = label_inputs.cuda()

            logits = reranker(context_inputs,
                              candidate_inputs,
                              label_inputs,
                              only_logits=True)

        logits = logits.detach().cpu().numpy()
        tmp_eval_accuracy = int(
            torch.sum(label_inputs[np.arange(label_inputs.shape[0]),
                                   np.argmax(logits, axis=1)] == 1))
        eval_accuracy += tmp_eval_accuracy
        nb_eval_examples += context_inputs.size(0)
        nb_eval_steps += 1

    normalized_eval_accuracy = eval_accuracy / nb_eval_examples
    logger.info("Eval accuracy: %.5f" % normalized_eval_accuracy)
    return normalized_eval_accuracy
Example #8
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"], 'log-eval')

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path

    embed_data_path = params["embed_data_path"]
    if embed_data_path is None or not os.path.exists(embed_data_path):
        embed_data_path = output_path

    # Init model
    reranker = BiEncoderRanker(params)
    reranker.model.eval()
    tokenizer = reranker.tokenizer
    n_gpu = reranker.n_gpu

    knn = params["knn"]
    use_types = params["use_types"]
    within_doc = params["within_doc"]
    data_split = params["data_split"]  # Default = "test"

    # Load test data
    test_samples = None
    entity_dictionary_loaded = False
    test_dictionary_pkl_path = os.path.join(pickle_src_path,
                                            'test_dictionary.pickle')
    test_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                             'test_tensor_data.pickle')
    test_mention_data_pkl_path = os.path.join(pickle_src_path,
                                              'test_mention_data.pickle')
    if params['transductive']:
        train_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                                  'train_tensor_data.pickle')
        train_mention_data_pkl_path = os.path.join(
            pickle_src_path, 'train_mention_data.pickle')
    if os.path.isfile(test_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(test_dictionary_pkl_path, 'rb') as read_handle:
            test_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if os.path.isfile(test_tensor_data_pkl_path) and os.path.isfile(
            test_mention_data_pkl_path):
        print("Loading stored processed test data...")
        with open(test_tensor_data_pkl_path, 'rb') as read_handle:
            test_tensor_data = pickle.load(read_handle)
        with open(test_mention_data_pkl_path, 'rb') as read_handle:
            mention_data = pickle.load(read_handle)
    else:
        test_samples = utils.read_dataset(data_split, params["data_path"])
        if not entity_dictionary_loaded:
            with open(os.path.join(params["data_path"], 'dictionary.pickle'),
                      'rb') as read_handle:
                test_dictionary = pickle.load(read_handle)

        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in test_samples[0].keys()
        if params["filter_unlabeled"]:
            # Filter samples without gold entities
            test_samples = list(
                filter(
                    lambda sample: (len(sample["labels"]) > 0)
                    if mult_labels else (sample["label"] is not None),
                    test_samples))
        logger.info("Read %d test samples." % len(test_samples))

        mention_data, test_dictionary, test_tensor_data = data_process.process_mention_data(
            test_samples,
            test_dictionary,
            tokenizer,
            params["max_context_length"],
            params["max_cand_length"],
            multi_label_key="labels" if mult_labels else None,
            context_key=params["context_key"],
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            knn=knn,
            dictionary_processed=entity_dictionary_loaded)
        print("Saving processed test data...")
        if not entity_dictionary_loaded:
            with open(test_dictionary_pkl_path, 'wb') as write_handle:
                pickle.dump(test_dictionary,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            entity_dictionary_loaded = True
        with open(test_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(test_tensor_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_mention_data_pkl_path, 'wb') as write_handle:
            pickle.dump(mention_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    # Store test dictionary token ids
    test_dict_vecs = torch.tensor(list(map(lambda x: x['ids'],
                                           test_dictionary)),
                                  dtype=torch.long)
    # Store test mention token ids
    test_men_vecs = test_tensor_data[:][0]

    n_entities = len(test_dict_vecs)
    n_mentions = len(test_tensor_data)

    if within_doc:
        if test_samples is None:
            test_samples, _ = read_data(data_split, params, logger)
        test_context_doc_ids = [s['context_doc_id'] for s in test_samples]

    if params["transductive"]:
        if os.path.isfile(train_tensor_data_pkl_path) and os.path.isfile(
                train_mention_data_pkl_path):
            print("Loading stored processed train data...")
            with open(train_tensor_data_pkl_path, 'rb') as read_handle:
                train_tensor_data = pickle.load(read_handle)
            with open(train_mention_data_pkl_path, 'rb') as read_handle:
                train_mention_data = pickle.load(read_handle)
        else:
            train_samples = utils.read_dataset('train', params["data_path"])

            # Check if dataset has multiple ground-truth labels
            mult_labels = "labels" in train_samples[0].keys()
            logger.info("Read %d test samples." % len(test_samples))

            train_mention_data, _, train_tensor_data = data_process.process_mention_data(
                train_samples,
                test_dictionary,
                tokenizer,
                params["max_context_length"],
                params["max_cand_length"],
                multi_label_key="labels" if mult_labels else None,
                context_key=params["context_key"],
                silent=params["silent"],
                logger=logger,
                debug=params["debug"],
                knn=knn,
                dictionary_processed=entity_dictionary_loaded)
            print("Saving processed train data...")
            with open(train_tensor_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_tensor_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            with open(train_mention_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_mention_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

        # Store train mention token ids
        train_men_vecs = train_tensor_data[:][0]
        n_mentions += len(train_tensor_data)
        n_train_mentions = len(train_tensor_data)

    # Values of k to run the evaluation against
    knn_vals = [0] + [2**i for i in range(int(math.log(knn, 2)) + 1)]
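    # e.g., knn=16 yields knn_vals = [0, 1, 2, 4, 8, 16]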
    # Store the maximum evaluation k
    max_knn = knn_vals[-1]

    time_start = time.time()

    # Check if graphs are already built
    graph_path = os.path.join(output_path, 'graphs.pickle')
    if not params['only_recall'] and os.path.isfile(graph_path):
        print("Loading stored joint graphs...")
        with open(graph_path, 'rb') as read_handle:
            joint_graphs = pickle.load(read_handle)
    else:
        # Initialize graphs to store mention-mention and mention-entity similarity score edges;
        # Keyed on k, the number of nearest mentions retrieved
        joint_graphs = {}
        for k in knn_vals:
            joint_graphs[k] = {
                'rows': np.array([]),
                'cols': np.array([]),
                'data': np.array([]),
                'shape': (n_entities + n_mentions, n_entities + n_mentions)
            }

        # Check and load stored embedding data
        embed_data_path = os.path.join(embed_data_path, 'embed_data.t7')
        embed_data = None
        if os.path.isfile(embed_data_path):
            embed_data = torch.load(embed_data_path)

        if use_types:
            if embed_data is not None:
                logger.info('Loading stored embeddings and computing indexes')
                dict_embeds = embed_data['dict_embeds']
                if 'dict_idxs_by_type' in embed_data:
                    dict_idxs_by_type = embed_data['dict_idxs_by_type']
                else:
                    dict_idxs_by_type = data_process.get_idxs_by_type(
                        test_dictionary)
                dict_indexes = data_process.get_index_from_embeds(
                    dict_embeds,
                    dict_idxs_by_type,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
                men_embeds = embed_data['men_embeds']
                if 'men_idxs_by_type' in embed_data:
                    men_idxs_by_type = embed_data['men_idxs_by_type']
                else:
                    men_idxs_by_type = data_process.get_idxs_by_type(
                        mention_data)
                men_indexes = data_process.get_index_from_embeds(
                    men_embeds,
                    men_idxs_by_type,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
            else:
                logger.info("Dictionary: Embedding and building index")
                dict_embeds, dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
                    reranker,
                    test_dict_vecs,
                    encoder_type="candidate",
                    n_gpu=n_gpu,
                    corpus=test_dictionary,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])
                logger.info("Queries: Embedding and building index")
                vecs = test_men_vecs
                men_data = mention_data
                if params['transductive']:
                    vecs = torch.cat((train_men_vecs, vecs), dim=0)
                    men_data = train_mention_data + mention_data
                men_embeds, men_indexes, men_idxs_by_type = data_process.embed_and_index(
                    reranker,
                    vecs,
                    encoder_type="context",
                    n_gpu=n_gpu,
                    corpus=men_data,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])
        else:
            if embed_data is not None:
                logger.info('Loading stored embeddings and computing indexes')
                dict_embeds = embed_data['dict_embeds']
                dict_index = data_process.get_index_from_embeds(
                    dict_embeds,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
                men_embeds = embed_data['men_embeds']
                men_index = data_process.get_index_from_embeds(
                    men_embeds,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
            else:
                logger.info("Dictionary: Embedding and building index")
                dict_embeds, dict_index = data_process.embed_and_index(
                    reranker,
                    test_dict_vecs,
                    'candidate',
                    n_gpu=n_gpu,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])
                logger.info("Queries: Embedding and building index")
                vecs = test_men_vecs
                if params['transductive']:
                    vecs = torch.cat((train_men_vecs, vecs), dim=0)
                men_embeds, men_index = data_process.embed_and_index(
                    reranker,
                    vecs,
                    'context',
                    n_gpu=n_gpu,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])

        # Save computed embedding data if not loaded from disk
        if embed_data is None:
            embed_data = {}
            embed_data['dict_embeds'] = dict_embeds
            embed_data['men_embeds'] = men_embeds
            if use_types:
                embed_data['dict_idxs_by_type'] = dict_idxs_by_type
                embed_data['men_idxs_by_type'] = men_idxs_by_type
            # NOTE: Cannot pickle faiss index because it is a SwigPyObject
            torch.save(embed_data,
                       embed_data_path,
                       pickle_protocol=pickle.HIGHEST_PROTOCOL)

        recall_accuracy = {
            2**i: 0
            for i in range(int(math.log(params['recall_k'], 2)) + 1)
        }
        recall_idxs = [0.] * params['recall_k']
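        # recall_idxs[i] counts queries whose gold entity is retrieved at rank
        # i; recall_accuracy[k] becomes the fraction with gold in the top k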

        logger.info("Starting KNN search...")
        # Fetch recall_k (default 16) knn entities for all mentions
        # Fetch (k+1) NN mention candidates
        if not use_types:
            _men_embeds = men_embeds
            if params['transductive']:
                _men_embeds = _men_embeds[n_train_mentions:]
            nn_ent_dists, nn_ent_idxs = dict_index.search(
                _men_embeds, params['recall_k'])
            n_mens_to_fetch = len(_men_embeds) if within_doc else max_knn + 1
            nn_men_dists, nn_men_idxs = men_index.search(
                _men_embeds, n_mens_to_fetch)
        else:
            query_len = len(men_embeds) - (n_train_mentions
                                           if params['transductive'] else 0)
            nn_ent_idxs = np.zeros((query_len, params['recall_k']))
            nn_ent_dists = np.zeros((query_len, params['recall_k']),
                                    dtype='float64')
            nn_men_idxs = -1 * np.ones((query_len, query_len), dtype=int)
            nn_men_dists = -1 * np.ones(
                (query_len, query_len), dtype='float64')
            for entity_type in men_indexes:
                if params['transductive']:
                    test_mask = men_idxs_by_type[
                        entity_type] >= n_train_mentions
                    men_embeds_by_type = men_embeds[
                        men_idxs_by_type[entity_type][test_mask]]
                else:
                    men_embeds_by_type = men_embeds[
                        men_idxs_by_type[entity_type]]
                nn_ent_dists_by_type, nn_ent_idxs_by_type = dict_indexes[
                    entity_type].search(men_embeds_by_type, params['recall_k'])
                nn_ent_idxs_by_type = np.array(
                    list(
                        map(lambda x: dict_idxs_by_type[entity_type][x],
                            nn_ent_idxs_by_type)))
                n_mens_to_fetch = len(
                    men_embeds_by_type) if within_doc else max_knn + 1
                nn_men_dists_by_type, nn_men_idxs_by_type = men_indexes[
                    entity_type].search(
                        men_embeds_by_type,
                        min(n_mens_to_fetch, len(men_embeds_by_type)))
                nn_men_idxs_by_type = np.array(
                    list(
                        map(lambda x: men_idxs_by_type[entity_type][x],
                            nn_men_idxs_by_type)))
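                # `i` indexes rows of the per-type search results, which cover
                # only the mentions actually searched above; in transductive
                # mode train mentions were excluded and are skipped below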
                i = -1
                for idx in men_idxs_by_type[entity_type]:
                    if params['transductive']:
                        idx -= n_train_mentions
                    if idx < 0:
                        continue
                    i += 1
                    nn_ent_idxs[idx] = nn_ent_idxs_by_type[i]
                    nn_ent_dists[idx] = nn_ent_dists_by_type[i]
                    nn_men_idxs[idx][:len(nn_men_idxs_by_type[i]
                                          )] = nn_men_idxs_by_type[i]
                    nn_men_dists[idx][:len(nn_men_dists_by_type[i]
                                           )] = nn_men_dists_by_type[i]
        logger.info("Search finished")

        logger.info('Building graphs')
        # Find the most similar entity and k-nn mentions for each mention query
        for idx in range(len(nn_ent_idxs)):
            # Get nearest entity candidate
            dict_cand_idx = nn_ent_idxs[idx][0]
            dict_cand_score = nn_ent_dists[idx][0]
            # Compute recall metric
            gold_idxs = mention_data[idx][
                "label_idxs"][:mention_data[idx]["n_labels"]]
            recall_idx = np.argwhere(nn_ent_idxs[idx] == gold_idxs[0])
            if len(recall_idx) != 0:
                recall_idx = int(recall_idx[0][0])
                recall_idxs[recall_idx] += 1.
                for recall_k in recall_accuracy:
                    if recall_idx < recall_k:
                        recall_accuracy[recall_k] += 1.
            if not params['only_recall']:
                filter_mask_neg1 = nn_men_idxs[idx] != -1
                men_cand_idxs = nn_men_idxs[idx][filter_mask_neg1]
                men_cand_scores = nn_men_dists[idx][filter_mask_neg1]

                if within_doc:
                    men_cand_idxs, wd_mask = filter_by_context_doc_id(
                        men_cand_idxs,
                        test_context_doc_ids[idx],
                        test_context_doc_ids,
                        return_numpy=True)
                    men_cand_scores = men_cand_scores[wd_mask]

                # Filter candidates to remove mention query and keep only the top k candidates
                filter_mask = men_cand_idxs != idx
                men_cand_idxs, men_cand_scores = men_cand_idxs[
                    filter_mask][:max_knn], men_cand_scores[
                        filter_mask][:max_knn]

                if params['transductive']:
                    idx += n_train_mentions
                # Add edges to the graphs
                for k in joint_graphs:
                    joint_graph = joint_graphs[k]
                    # Add mention-entity edge
                    joint_graph['rows'] = np.append(
                        joint_graph['rows'],
                        [n_entities + idx
                         ])  # Mentions added at an offset of maximum entities
                    joint_graph['cols'] = np.append(joint_graph['cols'],
                                                    dict_cand_idx)
                    joint_graph['data'] = np.append(joint_graph['data'],
                                                    dict_cand_score)
                    if k > 0:
                        # Add mention-mention edges
                        joint_graph['rows'] = np.append(
                            joint_graph['rows'],
                            [n_entities + idx] * len(men_cand_idxs[:k]))
                        joint_graph['cols'] = np.append(
                            joint_graph['cols'],
                            n_entities + men_cand_idxs[:k])
                        joint_graph['data'] = np.append(
                            joint_graph['data'], men_cand_scores[:k])

        if params['transductive']:
            # Add positive-infinity mention-entity edges from training queries
            # to their labeled entities, so partitioning can never separate a
            # training mention from its gold entity
            for idx, train_men in enumerate(train_mention_data):
                dict_cand_idx = train_men["label_idxs"][0]
                for k in joint_graphs:
                    joint_graph = joint_graphs[k]
                    joint_graph['rows'] = np.append(
                        joint_graph['rows'],
                        [n_entities + idx
                         ])  # Mentions added at an offset of maximum entities
                    joint_graph['cols'] = np.append(joint_graph['cols'],
                                                    dict_cand_idx)
                    joint_graph['data'] = np.append(joint_graph['data'],
                                                    float('inf'))

        # Compute and print recall metric
        recall_idx_mode = np.argmax(recall_idxs)
        recall_idx_mode_prop = recall_idxs[recall_idx_mode] / np.sum(
            recall_idxs)
        logger.info(f"""
        Recall metrics (for {len(nn_ent_idxs)} queries):
        ---------------""")
        logger.info(
            f"highest recall idx = {recall_idx_mode} ({recall_idxs[recall_idx_mode]}/{np.sum(recall_idxs)} = {recall_idx_mode_prop})"
        )
        for recall_k in recall_accuracy:
            recall_accuracy[recall_k] /= len(nn_ent_idxs)
            logger.info(f"recall@{recall_k} = {recall_accuracy[recall_k]}")

        if params['only_recall']:
            exit()

        # Pickle the graphs
        print("Saving joint graphs...")
        with open(graph_path, 'wb') as write_handle:
            pickle.dump(joint_graphs,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

        if params['only_embed_and_build']:
            logger.info(f"Saved embedding data at: {embed_data_path}")
            logger.info(f"Saved graphs at: {graph_path}")
            exit()

    graph_mode = params.get('graph_mode', None)

    result_overview = {
        'n_entities': n_entities,
        'n_mentions':
        n_mentions - (n_train_mentions if params['transductive'] else 0)
    }
    results = {}
    if graph_mode is None or graph_mode not in ['directed', 'undirected']:
        results['directed'] = []
        results['undirected'] = []
    else:
        results[graph_mode] = []

    knn_fetch_time = time.time() - time_start
    graph_processing_time = time.time()
    n_graphs_processed = 0.

    for mode in results:
        print(f'\nEvaluation mode: {mode.upper()}')
        for k in joint_graphs:
            if k <= knn:
                print(f"\nGraph (k={k}):")
                # Partition graph based on cluster-linking constraints
                partitioned_graph, clusters = partition_graph(
                    joint_graphs[k],
                    n_entities,
                    mode == 'directed',
                    return_clusters=True)
                # Infer predictions from clusters
                result = analyzeClusters(
                    clusters, test_dictionary, mention_data, k,
                    n_train_mentions if params['transductive'] else 0)
                # Store result
                results[mode].append(result)
                n_graphs_processed += 1

    avg_graph_processing_time = (time.time() -
                                 graph_processing_time) / n_graphs_processed
    avg_per_graph_time = (knn_fetch_time + avg_graph_processing_time) / 60

    execution_time = (time.time() - time_start) / 60
    # Store results (file names are timestamped with the current epoch seconds)
    output_file_name = os.path.join(output_path,
                                    f"eval_results_{int(time.time())}")

    try:
        for recall_k in recall_accuracy:
            result_overview[f'recall@{recall_k}'] = recall_accuracy[recall_k]
    except NameError:
        # recall_accuracy is only defined when the graphs were built this run
        logger.info(
            "Recall data not available since graphs were loaded from disk")

    for mode in results:
        mode_results = results[mode]
        result_overview[mode] = {}
        for r in mode_results:
            k = r['knn_mentions']
            result_overview[mode][f'accuracy@knn{k}'] = r['accuracy']
            logger.info(f"{mode} accuracy@knn{k} = {r['accuracy']}")
            output_file = f'{output_file_name}-{mode}-{k}.json'
            with open(output_file, 'w') as f:
                json.dump(r, f, indent=2)
                print(
                    f"\nPredictions ({mode}) @knn{k} saved at: {output_file}")
    with open(f'{output_file_name}.json', 'w') as f:
        json.dump(result_overview, f, indent=2)
        print(f"\nPredictions overview saved at: {output_file_name}.json")

    logger.info("\nThe avg. per graph evaluation time is {} minutes\n".format(
        avg_per_graph_time))
    logger.info(
        "\nThe total evaluation took {} minutes\n".format(execution_time))