Example #1
def main(params):
    time_start = time.time()

    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"], 'log-discovery')

    embed_data_path = params["embed_data_path"]
    if embed_data_path is None or not os.path.exists(embed_data_path):
        embed_data_path = output_path

    graph_path = params["graph_path"]
    if graph_path is None or not os.path.exists(graph_path):
        graph_path = output_path

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path

    rng = np.random.default_rng(seed=17)
    knn = params["knn"]
    use_types = params["use_types"]
    data_split = params["data_split"] # Default = "test"
    graph_mode = params.get('graph_mode', None)

    logger.info(f"Dataset: {data_split.upper()}")

    # Load evaluation data
    entity_dictionary_loaded = False
    dictionary_pkl_path = os.path.join(pickle_src_path, 'test_dictionary.pickle')
    tensor_data_pkl_path = os.path.join(pickle_src_path, 'test_tensor_data.pickle')
    mention_data_pkl_path = os.path.join(pickle_src_path, 'test_mention_data.pickle')
    print("Loading stored processed entity dictionary...")
    with open(dictionary_pkl_path, 'rb') as read_handle:
        dictionary = pickle.load(read_handle)
    print("Loading stored processed mention data...")
    with open(tensor_data_pkl_path, 'rb') as read_handle:
        tensor_data = pickle.load(read_handle)
    with open(mention_data_pkl_path, 'rb') as read_handle:
        mention_data = pickle.load(read_handle)

    # Load stored joint graphs
    graph_path = os.path.join(graph_path, 'graphs.pickle')
    print("Loading stored joint graphs...")
    with open(graph_path, 'rb') as read_handle:
        joint_graphs = pickle.load(read_handle)
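    # joint_graphs is keyed on k (the number of nearest mentions retrieved); each
    # value is a sparse adjacency in COO form with 'rows', 'cols', 'data', 'shape'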

    if not params['drop_all_entities']:
        # Embedding data is only needed when a subset of entities is dropped;
        # it is never used when drop_all_entities is True
        print("Loading stored embedding data...")
        embed_data_path = os.path.join(embed_data_path, 'embed_data.t7')
        embed_data = torch.load(embed_data_path)

    n_entities = len(dictionary)
    seen_mention_idxs = set()
    unseen_mention_idxs_map = {}
    if params["seen_data_path"] is not None:  # Plug data leakage
        with open(params["seen_data_path"], 'rb') as read_handle:
            seen_data = pickle.load(read_handle)
        seen_cui_idxs = set()
        for seen_men in seen_data:
            seen_cui_idxs.add(seen_men['label_idxs'][0])
        logger.info(f"CUIs seen at training: {len(seen_cui_idxs)}")
        filtered_mention_data = []
        for menidx, men in enumerate(mention_data):
            if men['label_idxs'][0] not in seen_cui_idxs:
                filtered_mention_data.append(men)
                unseen_mention_idxs_map[menidx] = len(filtered_mention_data) - 1
            else:
                seen_mention_idxs.add(menidx)
        if not params['no_drop_seen']:
            logger.info("Dropping mentions whose CUIs were seen during training")
            logger.info(f"Unfiltered mention size: {len(mention_data)}")
            mention_data = filtered_mention_data
            logger.info(f"Filtered mention size: {len(mention_data)}")
    n_mentions = len(mention_data)
    n_labels = 1  # Zeshel and MedMentions have single gold entity mentions

    mention_gold_cui_idxs = list(map(lambda x: x['label_idxs'][n_labels - 1], mention_data))
    ents_in_data = np.unique(mention_gold_cui_idxs)

    if params['drop_all_entities']:
        ent_drop_prop = 1
        n_ents_dropped = len(ents_in_data)
        n_mentions_wo_gold_ents = n_mentions
        logger.info(f"Dropping all {n_ents_dropped} entities found in mention set")
        set_dropped_ent_idxs = set()
    else:
        # Percentage of entities from the mention set to drop
        ent_drop_prop = 0.1
        
        logger.info(f"Dropping {ent_drop_prop*100}% of {len(ents_in_data)} entities found in mention set")

        # Get entity indices to drop
        n_ents_dropped = int(ent_drop_prop*len(ents_in_data))
        dropped_ent_idxs = rng.choice(ents_in_data, size=n_ents_dropped, replace=False)
        set_dropped_ent_idxs = set(dropped_ent_idxs)
        
        n_mentions_wo_gold_ents = sum(1 for x in mention_gold_cui_idxs if x in set_dropped_ent_idxs)
        logger.info(f"Dropped {n_ents_dropped} entities")
        logger.info(f"=> Mentions without gold entities = {n_mentions_wo_gold_ents}")

        # Load embeddings in order to compute new KNN entities after dropping
        print('Computing new dictionary indexes...')
        
        original_dict_embeds = embed_data['dict_embeds']
        keep_mask = np.ones(len(original_dict_embeds), dtype='bool')
        keep_mask[dropped_ent_idxs] = False
        dict_embeds = original_dict_embeds[keep_mask]
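        # Dictionary positions shift after masking, so build a new-index ->
        # original-index map to translate KNN hits back to original entity ids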

        new_to_old_dict_mapping = []
        for i in range(len(original_dict_embeds)):
            if keep_mask[i]:
                new_to_old_dict_mapping.append(i)

        men_embeds = embed_data['men_embeds']
        if use_types:
            dict_idxs_by_type = data_process.get_idxs_by_type(list(compress(dictionary, keep_mask)))
            dict_indexes = data_process.get_index_from_embeds(dict_embeds, dict_idxs_by_type, force_exact_search=params['force_exact_search'], probe_mult_factor=params['probe_mult_factor'])
            if 'men_idxs_by_type' in embed_data:
                men_idxs_by_type = embed_data['men_idxs_by_type']
            else:
                men_idxs_by_type = data_process.get_idxs_by_type(mention_data)
        else:
            dict_index = data_process.get_index_from_embeds(dict_embeds, force_exact_search=params['force_exact_search'], probe_mult_factor=params['probe_mult_factor'])
        
        # Fetch additional KNN entity to make sure every mention has a linked entity after dropping
        extra_entity_knn = []
        if use_types:
            for men_type in men_idxs_by_type:
                dict_index = dict_indexes[men_type]
                dict_type_idx_mapping = dict_idxs_by_type[men_type]
                q_men_embeds = men_embeds[men_idxs_by_type[men_type]]
                fetch_k = 1 if isinstance(dict_index, faiss.IndexFlatIP) else 16
                _, nn_idxs = dict_index.search(q_men_embeds, fetch_k)
                for i, men_idx in enumerate(men_idxs_by_type[men_type]):
                    r = n_entities + men_idx
                    q_nn_idxs = dict_type_idx_mapping[nn_idxs[i]]
                    q_nn_embeds = torch.tensor(dict_embeds[q_nn_idxs]).cuda()
                    q_scores = torch.flatten(
                        torch.mm(torch.tensor(q_men_embeds[i:i+1]).cuda(), q_nn_embeds.T)).cpu()
                    c, data = new_to_old_dict_mapping[q_nn_idxs[torch.argmax(q_scores)]], torch.max(q_scores)
                    extra_entity_knn.append((r,c,data))
        else:
            fetch_k = 1 if isinstance(dict_index, faiss.IndexFlatIP) else 16
            _, nn_idxs = dict_index.search(men_embeds, fetch_k)
            for men_idx, men_embed in enumerate(men_embeds):
                r = n_entities + men_idx
                q_nn_idxs = nn_idxs[men_idx]
                q_nn_embeds = torch.tensor(dict_embeds[q_nn_idxs]).cuda()
                q_scores = torch.flatten(
                    torch.mm(torch.tensor(np.expand_dims(men_embed, axis=0)).cuda(), q_nn_embeds.T)).cpu()
                c, data = new_to_old_dict_mapping[q_nn_idxs[torch.argmax(q_scores)]], torch.max(q_scores)
                extra_entity_knn.append((r,c,data))
        
        # Append the substitute entity edges to every stored graph so that each
        # mention still has at least one linked entity after the drop
        ex_rows = np.array([edge[0] for edge in extra_entity_knn])
        ex_cols = np.array([edge[1] for edge in extra_entity_knn])
        ex_data = np.array([float(edge[2]) for edge in extra_entity_knn])
        for k in joint_graphs:
            joint_graphs[k]['rows'] = np.concatenate((joint_graphs[k]['rows'], ex_rows))
            joint_graphs[k]['cols'] = np.concatenate((joint_graphs[k]['cols'], ex_cols))
            joint_graphs[k]['data'] = np.concatenate((joint_graphs[k]['data'], ex_data))

    results = {
        'data_split': data_split.upper(),
        'n_entities': n_entities,
        'n_mentions': n_mentions,
        'n_entities_dropped': f"{n_ents_dropped} ({ent_drop_prop*100}%)",
        'n_mentions_wo_gold_entities': n_mentions_wo_gold_ents
    }
    if graph_mode is None or graph_mode not in ['directed', 'undirected']:
        graph_mode = ['directed', 'undirected']
    else:
        graph_mode = [graph_mode]

    n_thresholds = params['n_thresholds'] # Default is 10
    exact_threshold = params.get('exact_threshold', None)
    exact_knn = params.get('exact_knn', None)
    
    kmeans = KMeans(n_clusters=n_thresholds, random_state=17)
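    # Cluster the 1-D edge weights with k-means and use the cluster centers as
    # candidate partitioning thresholds (0 is prepended as a no-threshold baseline)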

    # TODO: Baseline? (without dropping entities)

    for mode in graph_mode:
        best_result = -1.
        best_config = None
        for k in joint_graphs:
            if params['drop_all_entities']:
                # Drop all entities from the graph
                rows, cols, data = joint_graphs[k]['rows'], joint_graphs[k]['cols'], joint_graphs[k]['data']
                _f_row, _f_col, _f_data = [], [], []
                for ki in range(len(joint_graphs[k]['rows'])):
                    if joint_graphs[k]['cols'][ki] < n_entities or joint_graphs[k]['rows'][ki] < n_entities:
                        continue
                    # Remove mentions whose gold entity was seen during training
                    if len(seen_mention_idxs) > 0 and not params['no_drop_seen']:
                        if (joint_graphs[k]['rows'][ki] - n_entities) in seen_mention_idxs or \
                                (joint_graphs[k]['cols'][ki] - n_entities) in seen_mention_idxs:
                            continue
                    _f_row.append(joint_graphs[k]['rows'][ki])
                    _f_col.append(joint_graphs[k]['cols'][ki])
                    _f_data.append(joint_graphs[k]['data'][ki])
                joint_graphs[k]['rows'], joint_graphs[k]['cols'], joint_graphs[k]['data'] = list(map(np.array, (_f_row, _f_col, _f_data)))
            if (exact_knn is None and k > 0 and k <= knn) or (exact_knn is not None and k == exact_knn):
                if exact_threshold is not None:
                    thresholds = np.array([0, exact_threshold])
                else:
                    thresholds = np.sort(np.concatenate(([0], kmeans.fit(joint_graphs[k]['data'].reshape(-1,1)).cluster_centers_.flatten())))
                for thresh in thresholds:
                    print("\nPartitioning...")
                    logger.info(f"{mode.upper()}, k={k}, threshold={thresh}")
                    # Partition graph based on cluster-linking constraints
                    partitioned_graph, clusters = partition_graph(
                        joint_graphs[k], n_entities, mode == 'directed', return_clusters=True, exclude=set_dropped_ent_idxs, threshold=thresh, without_entities=params['drop_all_entities'])
                    # Analyze cluster against gold clusters
                    result = analyzeClusters(clusters, mention_gold_cui_idxs, n_entities, n_mentions, logger, unseen_mention_idxs_map, no_drop_seen=params['no_drop_seen'])
                    results[f'({mode}, {k}, {thresh})'] = result
                    if not params['no_drop_seen']:
                        if thresh != 0 and result['average'] > best_result:
                            best_result = result['average']
                            best_config = (mode, k, thresh)
        if not params['no_drop_seen']:
            results[f'best_{mode}_config'] = best_config
            results[f'best_{mode}_result'] = best_result
    
    # Store results
    # Suffix with epoch seconds so repeated runs don't overwrite each other
    output_file_name = os.path.join(
        output_path, f"{data_split}_eval_discovery_{int(time.time())}.json")

    with open(output_file_name, 'w') as f:
        json.dump(results, f, indent=2)
        print(f"\nAnalysis saved at: {output_file_name}")
    execution_time = (time.time() - time_start) / 60
    logger.info(f"\nTotal time taken: {execution_time} minutes\n")
Example #2
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"], 'log-eval')

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path

    embed_data_path = params["embed_data_path"]
    if embed_data_path is None or not os.path.exists(embed_data_path):
        embed_data_path = output_path

    cands_path = os.path.join(output_path, 'cands.pickle')

    # Init model
    reranker = BiEncoderRanker(params)
    reranker.model.eval()
    tokenizer = reranker.tokenizer
    n_gpu = reranker.n_gpu

    knn = params["knn"]
    use_types = params["use_types"]
    within_doc = params["within_doc"]
    data_split = params["data_split"]  # Default = "test"

    # Load test data
    test_samples = None
    entity_dictionary_loaded = False
    test_dictionary_pkl_path = os.path.join(pickle_src_path,
                                            'test_dictionary.pickle')
    test_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                             'test_tensor_data.pickle')
    test_mention_data_pkl_path = os.path.join(pickle_src_path,
                                              'test_mention_data.pickle')
    if params['transductive']:
        train_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                                  'train_tensor_data.pickle')
        train_mention_data_pkl_path = os.path.join(
            pickle_src_path, 'train_mention_data.pickle')
    if os.path.isfile(test_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(test_dictionary_pkl_path, 'rb') as read_handle:
            test_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if os.path.isfile(test_tensor_data_pkl_path) and os.path.isfile(
            test_mention_data_pkl_path):
        print("Loading stored processed test data...")
        with open(test_tensor_data_pkl_path, 'rb') as read_handle:
            test_tensor_data = pickle.load(read_handle)
        with open(test_mention_data_pkl_path, 'rb') as read_handle:
            mention_data = pickle.load(read_handle)
    else:
        test_samples = utils.read_dataset(data_split, params["data_path"])
        if not entity_dictionary_loaded:
            with open(os.path.join(params["data_path"], 'dictionary.pickle'),
                      'rb') as read_handle:
                test_dictionary = pickle.load(read_handle)

        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in test_samples[0].keys()
        if params["filter_unlabeled"]:
            # Filter samples without gold entities
            test_samples = list(
                filter(
                    lambda sample: (len(sample["labels"]) > 0)
                    if mult_labels else (sample["label"] is not None),
                    test_samples))
        logger.info("Read %d test samples." % len(test_samples))

        mention_data, test_dictionary, test_tensor_data = data_process.process_mention_data(
            test_samples,
            test_dictionary,
            tokenizer,
            params["max_context_length"],
            params["max_cand_length"],
            multi_label_key="labels" if mult_labels else None,
            context_key=params["context_key"],
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            knn=knn,
            dictionary_processed=entity_dictionary_loaded)
        print("Saving processed test data...")
        if not entity_dictionary_loaded:
            with open(test_dictionary_pkl_path, 'wb') as write_handle:
                pickle.dump(test_dictionary,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            entity_dictionary_loaded = True
        with open(test_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(test_tensor_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_mention_data_pkl_path, 'wb') as write_handle:
            pickle.dump(mention_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    # Store test dictionary token ids
    test_dict_vecs = torch.tensor(list(map(lambda x: x['ids'],
                                           test_dictionary)),
                                  dtype=torch.long)
    # Store test mention token ids
    test_men_vecs = test_tensor_data[:][0]

    n_entities = len(test_dict_vecs)
    n_mentions = len(test_tensor_data)

    if within_doc:
        if test_samples is None:
            test_samples, _ = read_data(data_split, params, logger)
        test_context_doc_ids = [s['context_doc_id'] for s in test_samples]

    if params["transductive"]:
        if os.path.isfile(train_tensor_data_pkl_path) and os.path.isfile(
                train_mention_data_pkl_path):
            print("Loading stored processed train data...")
            with open(train_tensor_data_pkl_path, 'rb') as read_handle:
                train_tensor_data = pickle.load(read_handle)
            with open(train_mention_data_pkl_path, 'rb') as read_handle:
                train_mention_data = pickle.load(read_handle)
        else:
            train_samples = utils.read_dataset('train', params["data_path"])

            # Check if dataset has multiple ground-truth labels
            mult_labels = "labels" in train_samples[0].keys()
            logger.info("Read %d test samples." % len(test_samples))

            train_mention_data, _, train_tensor_data = data_process.process_mention_data(
                train_samples,
                test_dictionary,
                tokenizer,
                params["max_context_length"],
                params["max_cand_length"],
                multi_label_key="labels" if mult_labels else None,
                context_key=params["context_key"],
                silent=params["silent"],
                logger=logger,
                debug=params["debug"],
                knn=knn,
                dictionary_processed=entity_dictionary_loaded)
            print("Saving processed train data...")
            with open(train_tensor_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_tensor_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            with open(train_mention_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_mention_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

        # Store train mention token ids
        train_men_vecs = train_tensor_data[:][0]
        n_mentions += len(train_tensor_data)
        n_train_mentions = len(train_tensor_data)

    # if os.path.isfile(cands_path):
    #     print("Loading stored candidates...")
    #     with open(cands_path, 'rb') as read_handle:
    #         men_cands = pickle.load(read_handle)
    # else:
    # Check and load stored embedding data
    embed_data_path = os.path.join(embed_data_path, 'embed_data.t7')
    embed_data = None
    if os.path.isfile(embed_data_path):
        embed_data = torch.load(embed_data_path)

    if use_types:
        if embed_data is not None:
            logger.info('Loading stored embeddings and computing indexes')
            dict_embeds = embed_data['dict_embeds']
            if 'dict_idxs_by_type' in embed_data:
                dict_idxs_by_type = embed_data['dict_idxs_by_type']
            else:
                dict_idxs_by_type = data_process.get_idxs_by_type(
                    test_dictionary)
            dict_indexes = data_process.get_index_from_embeds(
                dict_embeds,
                dict_idxs_by_type,
                force_exact_search=params['force_exact_search'],
                probe_mult_factor=params['probe_mult_factor'])
            men_embeds = embed_data['men_embeds']
            if 'men_idxs_by_type' in embed_data:
                men_idxs_by_type = embed_data['men_idxs_by_type']
            else:
                men_idxs_by_type = data_process.get_idxs_by_type(mention_data)
            men_indexes = data_process.get_index_from_embeds(
                men_embeds,
                men_idxs_by_type,
                force_exact_search=params['force_exact_search'],
                probe_mult_factor=params['probe_mult_factor'])
        else:
            logger.info("Dictionary: Embedding and building index")
            dict_embeds, dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
                reranker,
                test_dict_vecs,
                encoder_type="candidate",
                n_gpu=n_gpu,
                corpus=test_dictionary,
                force_exact_search=params['force_exact_search'],
                batch_size=params['embed_batch_size'],
                probe_mult_factor=params['probe_mult_factor'])
            logger.info("Queries: Embedding and building index")
            vecs = test_men_vecs
            men_data = mention_data
            if params['transductive']:
                vecs = torch.cat((train_men_vecs, vecs), dim=0)
                men_data = train_mention_data + mention_data
            men_embeds, men_indexes, men_idxs_by_type = data_process.embed_and_index(
                reranker,
                vecs,
                encoder_type="context",
                n_gpu=n_gpu,
                corpus=men_data,
                force_exact_search=params['force_exact_search'],
                batch_size=params['embed_batch_size'],
                probe_mult_factor=params['probe_mult_factor'])
    else:
        if embed_data is not None:
            logger.info('Loading stored embeddings and computing indexes')
            dict_embeds = embed_data['dict_embeds']
            dict_index = data_process.get_index_from_embeds(
                dict_embeds,
                force_exact_search=params['force_exact_search'],
                probe_mult_factor=params['probe_mult_factor'])
            men_embeds = embed_data['men_embeds']
            men_index = data_process.get_index_from_embeds(
                men_embeds,
                force_exact_search=params['force_exact_search'],
                probe_mult_factor=params['probe_mult_factor'])
        else:
            logger.info("Dictionary: Embedding and building index")
            dict_embeds, dict_index = data_process.embed_and_index(
                reranker,
                test_dict_vecs,
                'candidate',
                n_gpu=n_gpu,
                force_exact_search=params['force_exact_search'],
                batch_size=params['embed_batch_size'],
                probe_mult_factor=params['probe_mult_factor'])
            logger.info("Queries: Embedding and building index")
            vecs = test_men_vecs
            if params['transductive']:
                vecs = torch.cat((train_men_vecs, vecs), dim=0)
            men_embeds, men_index = data_process.embed_and_index(
                reranker,
                vecs,
                'context',
                n_gpu=n_gpu,
                force_exact_search=params['force_exact_search'],
                batch_size=params['embed_batch_size'],
                probe_mult_factor=params['probe_mult_factor'])

    # Save computed embedding data if not loaded from disk
    if embed_data is None:
        embed_data = {}
        embed_data['dict_embeds'] = dict_embeds
        embed_data['men_embeds'] = men_embeds
        if use_types:
            embed_data['dict_idxs_by_type'] = dict_idxs_by_type
            embed_data['men_idxs_by_type'] = men_idxs_by_type
        # NOTE: Cannot pickle faiss index because it is a SwigPyObject
        torch.save(embed_data,
                   embed_data_path,
                   pickle_protocol=pickle.HIGHEST_PROTOCOL)

    logger.info("Starting KNN search...")
    # Fetch (k+1) NN mention candidates
    if not use_types:
        _men_embeds = men_embeds
        if params['transductive']:
            _men_embeds = _men_embeds[n_train_mentions:]
        n_mens_to_fetch = len(_men_embeds) if within_doc else knn + 1
        nn_men_dists, nn_men_idxs = men_index.search(_men_embeds,
                                                     n_mens_to_fetch)
    else:
        query_len = len(men_embeds) - (n_train_mentions
                                       if params['transductive'] else 0)
        nn_men_idxs = -1 * np.ones((query_len, query_len), dtype=int)
        nn_men_dists = -1 * np.ones((query_len, query_len), dtype='float64')
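        # -1 entries mark unfilled slots; they are filtered out below before
        # candidate ranking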
        for entity_type in men_indexes:
            men_embeds_by_type = men_embeds[men_idxs_by_type[entity_type][
                men_idxs_by_type[entity_type] >=
                n_train_mentions]] if params['transductive'] else men_embeds[
                    men_idxs_by_type[entity_type]]
            n_mens_to_fetch = len(
                men_embeds_by_type) if within_doc else knn + 1
            nn_men_dists_by_type, nn_men_idxs_by_type = men_indexes[
                entity_type].search(
                    men_embeds_by_type,
                    min(n_mens_to_fetch, len(men_embeds_by_type)))
            nn_men_idxs_by_type = np.array(
                list(
                    map(lambda x: men_idxs_by_type[entity_type][x],
                        nn_men_idxs_by_type)))
            i = -1
            for idx in men_idxs_by_type[entity_type]:
                if params['transductive']:
                    idx -= n_train_mentions
                if idx < 0:
                    continue
                i += 1
                nn_men_idxs[idx][:len(nn_men_idxs_by_type[i]
                                      )] = nn_men_idxs_by_type[i]
                nn_men_dists[idx][:len(nn_men_dists_by_type[i]
                                       )] = nn_men_dists_by_type[i]
    logger.info("Search finished")

    logger.info(f'Calculating mention recall@{knn}')
    # Find the most similar entity and k-nn mentions for each mention query
    men_recall_knn = []
    men_cands = []
    cui_sums = defaultdict(int)
    for m in mention_data:
        cui_sums[m['label_cuis'][0]] += 1
    for idx in range(len(nn_men_idxs)):
        filter_mask_neg1 = nn_men_idxs[idx] != -1
        men_cand_idxs = nn_men_idxs[idx][filter_mask_neg1]
        men_cand_scores = nn_men_dists[idx][filter_mask_neg1]

        if within_doc:
            men_cand_idxs, wd_mask = filter_by_context_doc_id(
                men_cand_idxs,
                test_context_doc_ids[idx],
                test_context_doc_ids,
                return_numpy=True)
            men_cand_scores = men_cand_scores[wd_mask]

        # Filter candidates to remove mention query and keep only the top k candidates
        filter_mask = men_cand_idxs != idx
        men_cand_idxs, men_cand_scores = men_cand_idxs[
            filter_mask][:knn], men_cand_scores[filter_mask][:knn]
        assert len(men_cand_idxs) == knn
        men_cands.append(men_cand_idxs)
        # Calculate mention recall@k
        gold_cui = mention_data[idx]['label_cuis'][0]
        if cui_sums[gold_cui] > 1:
            recall_hit = [
                1. for midx in men_cand_idxs
                if mention_data[midx]['label_cuis'][0] == gold_cui
            ]
            recall_knn = sum(recall_hit) / min(cui_sums[gold_cui] - 1, knn)
            men_recall_knn.append(recall_knn)
    logger.info('Done')
    # assert len(men_recall_knn) == len(nn_men_idxs)

    # Pickle the graphs
    print(f"Saving top-{knn} candidates")
    with open(cands_path, 'wb') as write_handle:
        pickle.dump(men_cands, write_handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Output final results
    mean_recall_knn = np.mean(men_recall_knn)
    logger.info(f"Result: Final mean recall@{knn} = {mean_recall_knn}")
Example #3
def main(params):
    model_output_path = params["output_path"]
    if not os.path.exists(model_output_path):
        os.makedirs(model_output_path)
    logger = utils.get_logger(params["output_path"])

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = model_output_path

    knn = params["knn"]
    use_types = params["use_types"]

    # Init model
    reranker = BiEncoderRanker(params)
    tokenizer = reranker.tokenizer
    model = reranker.model

    device = reranker.device
    n_gpu = reranker.n_gpu

    if params["gradient_accumulation_steps"] < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(params["gradient_accumulation_steps"]))

    # To reach an effective batch size of `x` while accumulating gradients across `y` batches, use a per-step batch size of `z = x / y`
    params["train_batch_size"] = (params["train_batch_size"] //
                                  params["gradient_accumulation_steps"])
    train_batch_size = params["train_batch_size"]
    eval_batch_size = params["eval_batch_size"]
    grad_acc_steps = params["gradient_accumulation_steps"]
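    # e.g., requesting train_batch_size = 128 with grad_acc_steps = 4 gives a
    # per-step batch of 32, with one optimizer update every 4 steps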

    # Fix the random seeds
    seed = params["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if reranker.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    entity_dictionary_loaded = False
    entity_dictionary_pkl_path = os.path.join(pickle_src_path,
                                              'entity_dictionary.pickle')
    if os.path.isfile(entity_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(entity_dictionary_pkl_path, 'rb') as read_handle:
            entity_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if not params["only_evaluate"]:
        # Load train data
        train_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                                  'train_tensor_data.pickle')
        train_processed_data_pkl_path = os.path.join(
            pickle_src_path, 'train_processed_data.pickle')
        if os.path.isfile(train_tensor_data_pkl_path) and os.path.isfile(
                train_processed_data_pkl_path):
            print("Loading stored processed train data...")
            with open(train_tensor_data_pkl_path, 'rb') as read_handle:
                train_tensor_data = pickle.load(read_handle)
            with open(train_processed_data_pkl_path, 'rb') as read_handle:
                train_processed_data = pickle.load(read_handle)
        else:
            train_samples = utils.read_dataset("train", params["data_path"])
            if not entity_dictionary_loaded:
                with open(
                        os.path.join(params["data_path"], 'dictionary.pickle'),
                        'rb') as read_handle:
                    entity_dictionary = pickle.load(read_handle)

            # Check if dataset has multiple ground-truth labels
            mult_labels = "labels" in train_samples[0].keys()
            if params["filter_unlabeled"]:
                # Filter samples without gold entities
                train_samples = list(
                    filter(
                        lambda sample: (len(sample["labels"]) > 0)
                        if mult_labels else (sample["label"] is not None),
                        train_samples))
            logger.info("Read %d train samples." % len(train_samples))

            # For discovery experiment: Drop entities used in training that were dropped randomly from dev/test set
            if params["drop_entities"]:
                assert entity_dictionary_loaded
                drop_set_pkl_path = os.path.join(
                    pickle_src_path, 'drop_set_mention_data.pickle')
                with open(drop_set_pkl_path, 'rb') as read_handle:
                    drop_set_data = pickle.load(read_handle)
                drop_set_mention_gold_cui_idxs = list(
                    map(lambda x: x['label_idxs'][0], drop_set_data))
                ents_in_data = np.unique(drop_set_mention_gold_cui_idxs)
                ent_drop_prop = 0.1
                logger.info(
                    f"Dropping {ent_drop_prop*100}% of {len(ents_in_data)} entities found in drop set"
                )
                # Get entity indices to drop
                n_ents_dropped = int(ent_drop_prop * len(ents_in_data))
                rng = np.random.default_rng(seed=17)
                dropped_ent_idxs = rng.choice(ents_in_data,
                                              size=n_ents_dropped,
                                              replace=False)

                # Drop entities from dictionary (subsequent processing will automatically drop corresponding mentions)
                keep_mask = np.ones(len(entity_dictionary), dtype='bool')
                keep_mask[dropped_ent_idxs] = False
                entity_dictionary = np.array(entity_dictionary)[keep_mask]

            train_processed_data, entity_dictionary, train_tensor_data = data_process.process_mention_data(
                train_samples,
                entity_dictionary,
                tokenizer,
                params["max_context_length"],
                params["max_cand_length"],
                context_key=params["context_key"],
                multi_label_key="labels" if mult_labels else None,
                silent=params["silent"],
                logger=logger,
                debug=params["debug"],
                knn=knn,
                dictionary_processed=entity_dictionary_loaded)
            print("Saving processed train data...")
            if not entity_dictionary_loaded:
                with open(entity_dictionary_pkl_path, 'wb') as write_handle:
                    pickle.dump(entity_dictionary,
                                write_handle,
                                protocol=pickle.HIGHEST_PROTOCOL)
            with open(train_tensor_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_tensor_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            with open(train_processed_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_processed_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

        # Store the query mention vectors
        train_men_vecs = train_tensor_data[:][0]

        if params["shuffle"]:
            train_sampler = RandomSampler(train_tensor_data)
        else:
            train_sampler = SequentialSampler(train_tensor_data)

        train_dataloader = DataLoader(train_tensor_data,
                                      sampler=train_sampler,
                                      batch_size=train_batch_size)

    # Store the entity dictionary vectors
    entity_dict_vecs = torch.tensor(list(
        map(lambda x: x['ids'], entity_dictionary)),
                                    dtype=torch.long)

    # Load eval data
    valid_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                              'valid_tensor_data.pickle')
    valid_processed_data_pkl_path = os.path.join(
        pickle_src_path, 'valid_processed_data.pickle')
    if os.path.isfile(valid_tensor_data_pkl_path) and os.path.isfile(
            valid_processed_data_pkl_path):
        print("Loading stored processed valid data...")
        with open(valid_tensor_data_pkl_path, 'rb') as read_handle:
            valid_tensor_data = pickle.load(read_handle)
        with open(valid_processed_data_pkl_path, 'rb') as read_handle:
            valid_processed_data = pickle.load(read_handle)
    else:
        valid_samples = utils.read_dataset("valid", params["data_path"])
        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in valid_samples[0].keys()
        # Filter samples without gold entities
        valid_samples = list(
            filter(
                lambda sample: (len(sample["labels"]) > 0) if mult_labels else
                (sample["label"] is not None), valid_samples))
        logger.info("Read %d valid samples." % len(valid_samples))

        valid_processed_data, _, valid_tensor_data = data_process.process_mention_data(
            valid_samples,
            entity_dictionary,
            tokenizer,
            params["max_context_length"],
            params["max_cand_length"],
            context_key=params["context_key"],
            multi_label_key="labels" if mult_labels else None,
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            knn=knn,
            dictionary_processed=True)
        print("Saving processed valid data...")
        with open(valid_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(valid_tensor_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        with open(valid_processed_data_pkl_path, 'wb') as write_handle:
            pickle.dump(valid_processed_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    # Store the query mention vectors
    valid_men_vecs = valid_tensor_data[:][0]

    # valid_sampler = SequentialSampler(valid_tensor_data)
    # valid_dataloader = DataLoader(
    #     valid_tensor_data, sampler=valid_sampler, batch_size=eval_batch_size
    # )

    if params["only_evaluate"]:
        evaluate(reranker,
                 entity_dict_vecs,
                 valid_men_vecs,
                 device=device,
                 logger=logger,
                 knn=knn,
                 n_gpu=n_gpu,
                 entity_data=entity_dictionary,
                 query_data=valid_processed_data,
                 silent=params["silent"],
                 use_types=use_types or params["use_types_for_eval"],
                 embed_batch_size=params['embed_batch_size'],
                 force_exact_search=use_types or params["use_types_for_eval"]
                 or params["force_exact_search"],
                 probe_mult_factor=params['probe_mult_factor'])
        exit()

    time_start = time.time()
    utils.write_to_file(os.path.join(model_output_path, "training_params.txt"),
                        str(params))
    logger.info("Starting training")
    logger.info("device: {} n_gpu: {}, data_parallel: {}".format(
        device, n_gpu, params["data_parallel"]))

    # Set model to training mode
    optimizer = get_optimizer(model, params)
    scheduler = get_scheduler(params, optimizer, len(train_tensor_data),
                              logger)
    best_epoch_idx = -1
    best_score = -1
    num_train_epochs = params["num_train_epochs"]

    init_base_model_run = True if params.get("path_to_model",
                                             None) is None else False
    init_run_pkl_path = os.path.join(
        pickle_src_path, f'init_run_{"type" if use_types else "notype"}.t7')

    dict_embed_data = None

    for epoch_idx in trange(int(num_train_epochs), desc="Epoch"):
        model.train()
        torch.cuda.empty_cache()
        tr_loss = 0
        results = None

        # Check if embeddings and index can be loaded
        init_run_data_loaded = False
        if init_base_model_run:
            if os.path.isfile(init_run_pkl_path):
                logger.info('Loading init run data')
                init_run_data = torch.load(init_run_pkl_path)
                init_run_data_loaded = True
        load_stored_data = init_base_model_run and init_run_data_loaded

        # Compute mention and entity embeddings at the start of each epoch
        if use_types:
            if load_stored_data:
                train_dict_embeddings, dict_idxs_by_type = init_run_data[
                    'train_dict_embeddings'], init_run_data[
                        'dict_idxs_by_type']
                train_dict_indexes = data_process.get_index_from_embeds(
                    train_dict_embeddings,
                    dict_idxs_by_type,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
                train_men_embeddings, men_idxs_by_type = init_run_data[
                    'train_men_embeddings'], init_run_data['men_idxs_by_type']
                train_men_indexes = data_process.get_index_from_embeds(
                    train_men_embeddings,
                    men_idxs_by_type,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
            else:
                logger.info('Embedding and indexing')
                if dict_embed_data is not None:
                    train_dict_embeddings, train_dict_indexes, dict_idxs_by_type = dict_embed_data[
                        'dict_embeds'], dict_embed_data[
                            'dict_indexes'], dict_embed_data[
                                'dict_idxs_by_type']
                else:
                    train_dict_embeddings, train_dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
                        reranker,
                        entity_dict_vecs,
                        encoder_type="candidate",
                        n_gpu=n_gpu,
                        corpus=entity_dictionary,
                        force_exact_search=params['force_exact_search'],
                        batch_size=params['embed_batch_size'],
                        probe_mult_factor=params['probe_mult_factor'])
                train_men_embeddings, train_men_indexes, men_idxs_by_type = data_process.embed_and_index(
                    reranker,
                    train_men_vecs,
                    encoder_type="context",
                    n_gpu=n_gpu,
                    corpus=train_processed_data,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])
        else:
            if load_stored_data:
                train_dict_embeddings = init_run_data['train_dict_embeddings']
                train_dict_index = data_process.get_index_from_embeds(
                    train_dict_embeddings,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
                train_men_embeddings = init_run_data['train_men_embeddings']
                train_men_index = data_process.get_index_from_embeds(
                    train_men_embeddings,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
            else:
                logger.info('Embedding and indexing')
                if dict_embed_data is not None:
                    train_dict_embeddings, train_dict_index = dict_embed_data[
                        'dict_embeds'], dict_embed_data['dict_index']
                else:
                    train_dict_embeddings, train_dict_index = data_process.embed_and_index(
                        reranker,
                        entity_dict_vecs,
                        encoder_type="candidate",
                        n_gpu=n_gpu,
                        force_exact_search=params['force_exact_search'],
                        batch_size=params['embed_batch_size'],
                        probe_mult_factor=params['probe_mult_factor'])
                train_men_embeddings, train_men_index = data_process.embed_and_index(
                    reranker,
                    train_men_vecs,
                    encoder_type="context",
                    n_gpu=n_gpu,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])

        # Save the initial embeddings and index if this is the first run and data isn't persistent
        if init_base_model_run and not load_stored_data:
            init_run_data = {}
            init_run_data['train_dict_embeddings'] = train_dict_embeddings
            init_run_data['train_men_embeddings'] = train_men_embeddings
            if use_types:
                init_run_data['dict_idxs_by_type'] = dict_idxs_by_type
                init_run_data['men_idxs_by_type'] = men_idxs_by_type
            # NOTE: Cannot pickle faiss index because it is a SwigPyObject
            torch.save(init_run_data,
                       init_run_pkl_path,
                       pickle_protocol=pickle.HIGHEST_PROTOCOL)

        init_base_model_run = False

        if params["silent"]:
            iter_ = train_dataloader
        else:
            iter_ = tqdm(train_dataloader, desc="Batch")

        logger.info("Starting KNN search...")
        if not use_types:
            _, dict_nns = train_dict_index.search(train_men_embeddings, knn)
        else:
            dict_nns = np.zeros((len(train_men_embeddings), knn))
            for entity_type in train_men_indexes:
                men_embeds_by_type = train_men_embeddings[
                    men_idxs_by_type[entity_type]]
                _, dict_nns_by_type = train_dict_indexes[entity_type].search(
                    men_embeds_by_type, knn)
                dict_nns_idxs = np.array(
                    list(
                        map(lambda x: dict_idxs_by_type[entity_type][x],
                            dict_nns_by_type)))
                for i, idx in enumerate(men_idxs_by_type[entity_type]):
                    dict_nns[idx] = dict_nns_idxs[i]
        logger.info("Search finished")

        for step, batch in enumerate(iter_):
            batch = tuple(t.to(device) for t in batch)
            context_inputs, candidate_idxs, n_gold, mention_idxs = batch
            mention_embeddings = train_men_embeddings[mention_idxs.cpu()]

            if len(mention_embeddings.shape) == 1:
                mention_embeddings = np.expand_dims(mention_embeddings, axis=0)

            # context_inputs: Shape: batch x token_len
            # Flat array of entity dictionary indexes; expanded below into
            # (batch*knn) x token_len rows via entity_dict_vecs
            candidate_inputs = np.array([], dtype=np.int64)
            label_inputs = torch.tensor(
                [[1] + [0] * (knn - 1)] * n_gold.sum(),
                dtype=torch.float32)  # Shape: batch(with split rows) x knn
            context_inputs_split = torch.zeros(
                (label_inputs.size(0), context_inputs.size(1)),
                dtype=torch.long)  # Shape: batch(with split rows) x token_len
            # label_inputs = (candidate_idxs >= 0).type(torch.float32) # Shape: batch x knn

            for i, m_embed in enumerate(mention_embeddings):
                knn_dict_idxs = dict_nns[mention_idxs[i]]
                knn_dict_idxs = knn_dict_idxs.astype(np.int64).flatten()
                gold_idxs = candidate_idxs[i][:n_gold[i]].cpu()
                for ng, gold_idx in enumerate(gold_idxs):
                    context_inputs_split[i + ng] = context_inputs[i]
                    candidate_inputs = np.concatenate(
                        (candidate_inputs,
                         np.concatenate(
                             ([gold_idx],
                              knn_dict_idxs[~np.isin(knn_dict_idxs, gold_idxs)]
                              ))[:knn]))
            candidate_inputs = torch.tensor(
                list(
                    map(lambda x: entity_dict_vecs[x].numpy(),
                        candidate_inputs))).cuda()
            context_inputs_split = context_inputs_split.cuda()
            label_inputs = label_inputs.cuda()

            loss, _ = reranker(context_inputs_split,
                               candidate_inputs,
                               label_inputs,
                               pos_neg_loss=params["pos_neg_loss"])

            if grad_acc_steps > 1:
                loss = loss / grad_acc_steps

            tr_loss += loss.item()

            if (step + 1) % (params["print_interval"] * grad_acc_steps) == 0:
                logger.info("Step {} - epoch {} average loss: {}\n".format(
                    step,
                    epoch_idx,
                    tr_loss / (params["print_interval"] * grad_acc_steps),
                ))
                tr_loss = 0

            loss.backward()

            if (step + 1) % grad_acc_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               params["max_grad_norm"])
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if (step + 1) % (params["eval_interval"] * grad_acc_steps) == 0:
                logger.info("Evaluation on the development dataset")
                evaluate(reranker,
                         entity_dict_vecs,
                         valid_men_vecs,
                         device=device,
                         logger=logger,
                         knn=knn,
                         n_gpu=n_gpu,
                         entity_data=entity_dictionary,
                         query_data=valid_processed_data,
                         silent=params["silent"],
                         use_types=use_types or params["use_types_for_eval"],
                         embed_batch_size=params['embed_batch_size'],
                         force_exact_search=use_types
                         or params["use_types_for_eval"]
                         or params["force_exact_search"],
                         probe_mult_factor=params['probe_mult_factor'])
                model.train()
                logger.info("\n")

        logger.info("***** Saving fine-tuned model *****")
        epoch_output_folder_path = os.path.join(model_output_path,
                                                "epoch_{}".format(epoch_idx))
        utils.save_model(model, tokenizer, epoch_output_folder_path)
        logger.info(f"Model saved at {epoch_output_folder_path}")

        normalized_accuracy, dict_embed_data = evaluate(
            reranker,
            entity_dict_vecs,
            valid_men_vecs,
            device=device,
            logger=logger,
            knn=knn,
            n_gpu=n_gpu,
            entity_data=entity_dictionary,
            query_data=valid_processed_data,
            silent=params["silent"],
            use_types=use_types or params["use_types_for_eval"],
            embed_batch_size=params['embed_batch_size'],
            force_exact_search=use_types or params["use_types_for_eval"]
            or params["force_exact_search"],
            probe_mult_factor=params['probe_mult_factor'])

        if normalized_accuracy > best_score:
            best_score = normalized_accuracy
            best_epoch_idx = epoch_idx
        logger.info("\n")

    execution_time = (time.time() - time_start) / 60
    utils.write_to_file(
        os.path.join(model_output_path, "training_time.txt"),
        "The training took {} minutes\n".format(execution_time),
    )
    logger.info("The training took {} minutes\n".format(execution_time))

    # save the best model in the parent_dir
    logger.info("Best performance in epoch: {}".format(best_epoch_idx))
    params["path_to_model"] = os.path.join(model_output_path,
                                           "epoch_{}".format(best_epoch_idx))
    utils.save_model(reranker.model, tokenizer, model_output_path)
    logger.info(f"Best model saved at {model_output_path}")
Example #4
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"], 'log')

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path

    embed_data_path = params["embed_data_path"]
    if embed_data_path is None or not os.path.exists(embed_data_path):
        embed_data_path = output_path

    # Init model
    reranker = BiEncoderRanker(params)
    reranker.model.eval()
    tokenizer = reranker.tokenizer
    n_gpu = reranker.n_gpu

    knn = params["knn"]  # Use as the max-knn value for the graph construction
    use_types = params["use_types"]
    # within_doc = params["within_doc"]
    data_split = params["data_split"]  # Default = "test"

    # Load test data
    entity_dictionary_loaded = False
    test_dictionary_pkl_path = os.path.join(pickle_src_path,
                                            'test_dictionary.pickle')
    test_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                             'test_tensor_data.pickle')
    test_mention_data_pkl_path = os.path.join(pickle_src_path,
                                              'test_mention_data.pickle')
    # if params['transductive']:
    #     train_tensor_data_pkl_path = os.path.join(pickle_src_path, 'train_tensor_data.pickle')
    #     train_mention_data_pkl_path = os.path.join(pickle_src_path, 'train_mention_data.pickle')
    if os.path.isfile(test_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(test_dictionary_pkl_path, 'rb') as read_handle:
            test_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if os.path.isfile(test_tensor_data_pkl_path) and os.path.isfile(
            test_mention_data_pkl_path):
        print("Loading stored processed test data...")
        with open(test_tensor_data_pkl_path, 'rb') as read_handle:
            test_tensor_data = pickle.load(read_handle)
        with open(test_mention_data_pkl_path, 'rb') as read_handle:
            mention_data = pickle.load(read_handle)
    else:
        test_samples = utils.read_dataset(data_split, params["data_path"])
        if not entity_dictionary_loaded:
            with open(os.path.join(params["data_path"], 'dictionary.pickle'),
                      'rb') as read_handle:
                test_dictionary = pickle.load(read_handle)

        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in test_samples[0].keys()
        # Filter samples without gold entities
        test_samples = list(
            filter(
                lambda sample: (len(sample["labels"]) > 0) if mult_labels else
                (sample["label"] is not None), test_samples))
        logger.info("Read %d test samples." % len(test_samples))

        mention_data, test_dictionary, test_tensor_data = data_process.process_mention_data(
            test_samples,
            test_dictionary,
            tokenizer,
            params["max_context_length"],
            params["max_cand_length"],
            multi_label_key="labels" if mult_labels else None,
            context_key=params["context_key"],
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            knn=knn,
            dictionary_processed=entity_dictionary_loaded)
        print("Saving processed test data...")
        if not entity_dictionary_loaded:
            with open(test_dictionary_pkl_path, 'wb') as write_handle:
                pickle.dump(test_dictionary,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(test_tensor_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_mention_data_pkl_path, 'wb') as write_handle:
            pickle.dump(mention_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    # Reducing the entity dictionary to only the ground truth of the mention queries
    # Combining the entities and mentions into one structure for joint embedding and indexing
    new_ents = {}
    new_ents_arr = []
    men_labels = []
    for men in mention_data:
        ent = men['label_idxs'][0]
        if ent not in new_ents:
            new_ents[ent] = len(new_ents_arr)
            new_ents_arr.append(ent)
        men_labels.append(new_ents[ent])
    ent_labels = list(range(len(new_ents_arr)))
    new_ent_vecs = torch.tensor(
        list(map(lambda x: test_dictionary[x]['ids'], new_ents_arr)))
    new_ent_types = list(
        map(lambda x: {"type": test_dictionary[x]['type']}, new_ents_arr))
    test_men_vecs = test_tensor_data[:][0]

    n_mentions = len(test_tensor_data)
    n_entities = len(new_ent_vecs)
    n_embeds = n_mentions + n_entities
    leaf_labels = np.array(ent_labels + men_labels, dtype=int)
    all_vecs = torch.cat((new_ent_vecs, test_men_vecs))
    all_types = new_ent_types + mention_data  # Array of dicts containing key "type" for selected ents and all mentions

    # Values of k to run the evaluation against
    knn_vals = ([25 * 2**i for i in range(int(math.log(knn / 25, 2)) + 1)]
                if params["exact_knn"] is None else [params["exact_knn"]])
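    # e.g. knn = 100 yields [25, 50, 100]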
    # Store the maximum evaluation k
    max_knn = knn_vals[-1]

    time_start = time.time()

    # Check if graphs are already built
    graph_path = os.path.join(output_path, 'graphs.pickle')
    if os.path.isfile(graph_path):
        print("Loading stored joint graphs...")
        with open(graph_path, 'rb') as read_handle:
            joint_graphs = pickle.load(read_handle)
    else:
        # Initialize graphs to store mention-mention and mention-entity similarity score edges;
        # Keyed on k, the number of nearest mentions retrieved
        joint_graphs = {}
        for k in knn_vals:
            joint_graphs[k] = {
                'rows': np.array([]),
                'cols': np.array([]),
                'data': np.array([]),
                'shape': (n_embeds, n_embeds)
            }
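        # 'rows'/'cols'/'data'/'shape' mirror scipy's COO sparse-matrix fields:
        # edge endpoints and their similarity scores are appended incrementally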

        # Check and load stored embedding data
        embed_data_path = os.path.join(embed_data_path, 'embed_data.t7')
        embed_data = None
        if os.path.isfile(embed_data_path):
            embed_data = torch.load(embed_data_path)
        if use_types:
            if embed_data is not None:
                logger.info('Loading stored embeddings')
                embeds = embed_data['embeds']
                if 'idxs_by_type' in embed_data:
                    idxs_by_type = embed_data['idxs_by_type']
                else:
                    idxs_by_type = data_process.get_idxs_by_type(all_types)
            else:
                logger.info("Embedding data")
                dict_embeds = data_process.embed_and_index(
                    reranker,
                    all_vecs[:n_entities],
                    encoder_type='candidate',
                    only_embed=True,
                    n_gpu=n_gpu,
                    batch_size=params['embed_batch_size'])
                men_embeds = data_process.embed_and_index(
                    reranker,
                    all_vecs[n_entities:],
                    encoder_type='context',
                    only_embed=True,
                    n_gpu=n_gpu,
                    batch_size=params['embed_batch_size'])
                embeds = np.concatenate((dict_embeds, men_embeds), axis=0)
                idxs_by_type = data_process.get_idxs_by_type(all_types)
            # NOTE: the search indexes are built below, after optional normalization
        else:
            if embed_data is not None:
                logger.info('Loading stored embeddings')
                embeds = embed_data['embeds']
            else:
                logger.info("Embedding data")
                dict_embeds = data_process.embed_and_index(
                    reranker,
                    all_vecs[:n_entities],
                    encoder_type='candidate',
                    only_embed=True,
                    n_gpu=n_gpu,
                    batch_size=params['embed_batch_size'])
                men_embeds = data_process.embed_and_index(
                    reranker,
                    all_vecs[n_entities:],
                    encoder_type='context',
                    only_embed=True,
                    n_gpu=n_gpu,
                    batch_size=params['embed_batch_size'])
                embeds = np.concatenate((dict_embeds, men_embeds), axis=0)
            # NOTE: the search index is built below, after optional normalization
        # Save computed embedding data if not loaded from disk
        if embed_data is None:
            embed_data = {}
            embed_data['embeds'] = embeds
            if use_types:
                embed_data['idxs_by_type'] = idxs_by_type
            # NOTE: Cannot pickle faiss index because it is a SwigPyObject
            torch.save(embed_data,
                       embed_data_path,
                       pickle_protocol=pickle.HIGHEST_PROTOCOL)

        # Build faiss search index
        if params["normalize_embeds"]:
            embeds = normalize(embeds, axis=1)  # L2-normalize each embedding vector (row-wise)
        logger.info("Building KNN index...")
        if use_types:
            search_indexes = data_process.get_index_from_embeds(
                embeds, corpus_idxs=idxs_by_type, force_exact_search=True)
        else:
            search_index = data_process.get_index_from_embeds(
                embeds, force_exact_search=True)

        logger.info("Starting KNN search...")
        if not use_types:
            faiss_dists, faiss_idxs = search_index.search(embeds, max_knn + 1)
        else:
            query_len = n_embeds
            faiss_idxs = np.zeros((query_len, max_knn + 1), dtype=int)
            faiss_dists = np.zeros((query_len, max_knn + 1), dtype=float)
            for entity_type in search_indexes:
                embeds_by_type = embeds[idxs_by_type[entity_type]]
                nn_dists_by_type, nn_idxs_by_type = search_indexes[
                    entity_type].search(embeds_by_type, max_knn + 1)
                for i, idx in enumerate(idxs_by_type[entity_type]):
                    faiss_idxs[idx] = nn_idxs_by_type[i]
                    faiss_dists[idx] = nn_dists_by_type[i]
        logger.info("Search finished")

        logger.info('Building graphs')
        # Find the most similar nodes for each mention and node in the set (minus self)
        for idx in trange(n_embeds):
            # Compute adjacent node edge weight
            if idx != 0:
                adj_idx = idx - 1
                adj_data = embeds[adj_idx] @ embeds[idx]
            nn_idxs = faiss_idxs[idx]
            nn_scores = faiss_dists[idx]
            # Filter candidates to remove mention query and keep only the top k candidates
            filter_mask = nn_idxs != idx
            nn_idxs, nn_scores = nn_idxs[filter_mask][:max_knn], nn_scores[
                filter_mask][:max_knn]
            # Add edges to the graphs
            for k in joint_graphs:
                joint_graph = joint_graphs[k]
                # Add an edge to the adjacent node to force the graph to be connected
                if idx != 0:
                    joint_graph['rows'] = np.append(joint_graph['rows'],
                                                    adj_idx)
                    joint_graph['cols'] = np.append(joint_graph['cols'], idx)
                    joint_graph['data'] = np.append(joint_graph['data'],
                                                    adj_data)
                # Add k-nearest-neighbor edges
                joint_graph['rows'] = np.append(joint_graph['rows'],
                                                [idx] * len(nn_idxs[:k]))
                joint_graph['cols'] = np.append(joint_graph['cols'],
                                                nn_idxs[:k])
                joint_graph['data'] = np.append(joint_graph['data'],
                                                nn_scores[:k])

        # Pickle the graphs
        print("Saving joint graphs...")
        with open(graph_path, 'wb') as write_handle:
            pickle.dump(joint_graphs,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

        if params['only_embed_and_build']:
            logger.info(f"Saved embedding data at: {embed_data_path}")
            logger.info(f"Saved graphs at: {graph_path}")
            exit()

    # Computed here rather than in the build branch so it is defined
    # even when the graphs were loaded from disk
    knn_fetch_time = time.time() - time_start

    results = {
        'n_leaves': n_embeds,
        'n_entities': n_entities,
        'n_mentions': n_mentions
    }

    graph_processing_time = time.time()
    n_graphs_processed = 0.
    # Different HAC linkage functions to run the analyses over
    linkage_fns = (["single", "complete", "average"]
                   if params["linkage"] is None else [params["linkage"]])

    for fn in linkage_fns:
        logger.info(f"Linkage function: {fn}")
        purities = []
        fn_result = {}
        for k in joint_graphs:
            graph = hg.UndirectedGraph(n_embeds)
            graph.add_edges(joint_graphs[k]['rows'], joint_graphs[k]['cols'])
            # Negate the edge data since Higra expects weights as distances, not similarities
            weights = -joint_graphs[k]['data']
            tree = get_hac_tree(graph, weights, linkage=fn)
            purity = hg.dendrogram_purity(tree, leaf_labels)
            fn_result[f"purity@{k}nn"] = purity
            logger.info(f"purity@{k}nn = {purity}")
            purities.append(purity)
            n_graphs_processed += 1
        fn_result["average"] = round(np.mean(purities), 4)
        logger.info(f"average = {fn_result['average']}")
        results[fn] = fn_result

    avg_graph_processing_time = (time.time() -
                                 graph_processing_time) / n_graphs_processed
    avg_per_graph_time = (knn_fetch_time + avg_graph_processing_time) / 60
    execution_time = (time.time() - time_start) / 60

    # Store results
    output_file_name = os.path.join(
        output_path,
        f"results_{__import__('calendar').timegm(__import__('time').gmtime())}"
    )

    logger.info(f"Results: \n {results}")
    with open(f'{output_file_name}.json', 'w') as f:
        json.dump(results, f, indent=2)
        print(f"\nResults saved at: {output_file_name}.json")

    logger.info("\nThe avg. per graph evaluation time is {} minutes\n".format(
        avg_per_graph_time))
    logger.info(
        "\nThe total evaluation took {} minutes\n".format(execution_time))
Example #5
0
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"], 'log-eval')

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path

    embed_data_path = params["embed_data_path"]
    if embed_data_path is None or not os.path.exists(embed_data_path):
        embed_data_path = output_path

    # Init model
    reranker = BiEncoderRanker(params)
    reranker.model.eval()
    tokenizer = reranker.tokenizer
    n_gpu = reranker.n_gpu

    knn = params["knn"]
    use_types = params["use_types"]
    within_doc = params["within_doc"]
    data_split = params["data_split"]  # Default = "test"

    # Load test data
    test_samples = None
    entity_dictionary_loaded = False
    test_dictionary_pkl_path = os.path.join(pickle_src_path,
                                            'test_dictionary.pickle')
    test_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                             'test_tensor_data.pickle')
    test_mention_data_pkl_path = os.path.join(pickle_src_path,
                                              'test_mention_data.pickle')
    if params['transductive']:
        train_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                                  'train_tensor_data.pickle')
        train_mention_data_pkl_path = os.path.join(
            pickle_src_path, 'train_mention_data.pickle')
    if os.path.isfile(test_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(test_dictionary_pkl_path, 'rb') as read_handle:
            test_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if os.path.isfile(test_tensor_data_pkl_path) and os.path.isfile(
            test_mention_data_pkl_path):
        print("Loading stored processed test data...")
        with open(test_tensor_data_pkl_path, 'rb') as read_handle:
            test_tensor_data = pickle.load(read_handle)
        with open(test_mention_data_pkl_path, 'rb') as read_handle:
            mention_data = pickle.load(read_handle)
    else:
        test_samples = utils.read_dataset(data_split, params["data_path"])
        if not entity_dictionary_loaded:
            with open(os.path.join(params["data_path"], 'dictionary.pickle'),
                      'rb') as read_handle:
                test_dictionary = pickle.load(read_handle)

        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in test_samples[0].keys()
        if params["filter_unlabeled"]:
            # Filter samples without gold entities
            test_samples = list(
                filter(
                    lambda sample: (len(sample["labels"]) > 0)
                    if mult_labels else (sample["label"] is not None),
                    test_samples))
        logger.info("Read %d test samples." % len(test_samples))

        mention_data, test_dictionary, test_tensor_data = data_process.process_mention_data(
            test_samples,
            test_dictionary,
            tokenizer,
            params["max_context_length"],
            params["max_cand_length"],
            multi_label_key="labels" if mult_labels else None,
            context_key=params["context_key"],
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            knn=knn,
            dictionary_processed=entity_dictionary_loaded)
        print("Saving processed test data...")
        if not entity_dictionary_loaded:
            with open(test_dictionary_pkl_path, 'wb') as write_handle:
                pickle.dump(test_dictionary,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            entity_dictionary_loaded = True
        with open(test_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(test_tensor_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_mention_data_pkl_path, 'wb') as write_handle:
            pickle.dump(mention_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    # Store test dictionary token ids
    test_dict_vecs = torch.tensor(list(map(lambda x: x['ids'],
                                           test_dictionary)),
                                  dtype=torch.long)
    # Store test mention token ids
    test_men_vecs = test_tensor_data[:][0]

    n_entities = len(test_dict_vecs)
    n_mentions = len(test_tensor_data)

    if within_doc:
        if test_samples is None:
            test_samples, _ = read_data(data_split, params, logger)
        test_context_doc_ids = [s['context_doc_id'] for s in test_samples]

    if params["transductive"]:
        if os.path.isfile(train_tensor_data_pkl_path) and os.path.isfile(
                train_mention_data_pkl_path):
            print("Loading stored processed train data...")
            with open(train_tensor_data_pkl_path, 'rb') as read_handle:
                train_tensor_data = pickle.load(read_handle)
            with open(train_mention_data_pkl_path, 'rb') as read_handle:
                train_mention_data = pickle.load(read_handle)
        else:
            train_samples = utils.read_dataset('train', params["data_path"])

            # Check if dataset has multiple ground-truth labels
            mult_labels = "labels" in train_samples[0].keys()
            logger.info("Read %d test samples." % len(test_samples))

            train_mention_data, _, train_tensor_data = data_process.process_mention_data(
                train_samples,
                test_dictionary,
                tokenizer,
                params["max_context_length"],
                params["max_cand_length"],
                multi_label_key="labels" if mult_labels else None,
                context_key=params["context_key"],
                silent=params["silent"],
                logger=logger,
                debug=params["debug"],
                knn=knn,
                dictionary_processed=entity_dictionary_loaded)
            print("Saving processed train data...")
            with open(train_tensor_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_tensor_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            with open(train_mention_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_mention_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

        # Store train mention token ids
        train_men_vecs = train_tensor_data[:][0]
        n_mentions += len(train_tensor_data)
        n_train_mentions = len(train_tensor_data)

    # Values of k to run the evaluation against
    knn_vals = [0] + [2**i for i in range(int(math.log(knn, 2)) + 1)]
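    # e.g. knn = 16 yields [0, 1, 2, 4, 8, 16]; the k = 0 graph keeps only mention-entity edges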
    # Store the maximum evaluation k
    max_knn = knn_vals[-1]

    time_start = time.time()

    # Check if graphs are already built
    graph_path = os.path.join(output_path, 'graphs.pickle')
    if not params['only_recall'] and os.path.isfile(graph_path):
        print("Loading stored joint graphs...")
        with open(graph_path, 'rb') as read_handle:
            joint_graphs = pickle.load(read_handle)
    else:
        # Initialize graphs to store mention-mention and mention-entity similarity score edges;
        # Keyed on k, the number of nearest mentions retrieved
        joint_graphs = {}
        for k in knn_vals:
            joint_graphs[k] = {
                'rows': np.array([]),
                'cols': np.array([]),
                'data': np.array([]),
                'shape': (n_entities + n_mentions, n_entities + n_mentions)
            }

        # Check and load stored embedding data
        embed_data_path = os.path.join(embed_data_path, 'embed_data.t7')
        embed_data = None
        if os.path.isfile(embed_data_path):
            embed_data = torch.load(embed_data_path)

        if use_types:
            if embed_data is not None:
                logger.info('Loading stored embeddings and computing indexes')
                dict_embeds = embed_data['dict_embeds']
                if 'dict_idxs_by_type' in embed_data:
                    dict_idxs_by_type = embed_data['dict_idxs_by_type']
                else:
                    dict_idxs_by_type = data_process.get_idxs_by_type(
                        test_dictionary)
                dict_indexes = data_process.get_index_from_embeds(
                    dict_embeds,
                    dict_idxs_by_type,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
                men_embeds = embed_data['men_embeds']
                if 'men_idxs_by_type' in embed_data:
                    men_idxs_by_type = embed_data['men_idxs_by_type']
                else:
                    men_idxs_by_type = data_process.get_idxs_by_type(
                        mention_data)
                men_indexes = data_process.get_index_from_embeds(
                    men_embeds,
                    men_idxs_by_type,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
            else:
                logger.info("Dictionary: Embedding and building index")
                dict_embeds, dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
                    reranker,
                    test_dict_vecs,
                    encoder_type="candidate",
                    n_gpu=n_gpu,
                    corpus=test_dictionary,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])
                logger.info("Queries: Embedding and building index")
                vecs = test_men_vecs
                men_data = mention_data
                if params['transductive']:
                    vecs = torch.cat((train_men_vecs, vecs), dim=0)
                    men_data = train_mention_data + mention_data
                men_embeds, men_indexes, men_idxs_by_type = data_process.embed_and_index(
                    reranker,
                    vecs,
                    encoder_type="context",
                    n_gpu=n_gpu,
                    corpus=men_data,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])
        else:
            if embed_data is not None:
                logger.info('Loading stored embeddings and computing indexes')
                dict_embeds = embed_data['dict_embeds']
                dict_index = data_process.get_index_from_embeds(
                    dict_embeds,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
                men_embeds = embed_data['men_embeds']
                men_index = data_process.get_index_from_embeds(
                    men_embeds,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
            else:
                logger.info("Dictionary: Embedding and building index")
                dict_embeds, dict_index = data_process.embed_and_index(
                    reranker,
                    test_dict_vecs,
                    'candidate',
                    n_gpu=n_gpu,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])
                logger.info("Queries: Embedding and building index")
                vecs = test_men_vecs
                if params['transductive']:
                    vecs = torch.cat((train_men_vecs, vecs), dim=0)
                men_embeds, men_index = data_process.embed_and_index(
                    reranker,
                    vecs,
                    'context',
                    n_gpu=n_gpu,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])

        # Save computed embedding data if not loaded from disk
        if embed_data is None:
            embed_data = {}
            embed_data['dict_embeds'] = dict_embeds
            embed_data['men_embeds'] = men_embeds
            if use_types:
                embed_data['dict_idxs_by_type'] = dict_idxs_by_type
                embed_data['men_idxs_by_type'] = men_idxs_by_type
            # NOTE: Cannot pickle faiss index because it is a SwigPyObject
            torch.save(embed_data,
                       embed_data_path,
                       pickle_protocol=pickle.HIGHEST_PROTOCOL)

        recall_accuracy = {
            2**i: 0
            for i in range(int(math.log(params['recall_k'], 2)) + 1)
        }
        recall_idxs = [0.] * params['recall_k']
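        # recall_idxs[r] counts queries whose gold entity was retrieved at rank r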

        logger.info("Starting KNN search...")
        # Fetch recall_k (default 16) knn entities for all mentions
        # Fetch (k+1) NN mention candidates
        if not use_types:
            _men_embeds = men_embeds
            if params['transductive']:
                _men_embeds = _men_embeds[n_train_mentions:]
            nn_ent_dists, nn_ent_idxs = dict_index.search(
                _men_embeds, params['recall_k'])
            n_mens_to_fetch = len(_men_embeds) if within_doc else max_knn + 1
            nn_men_dists, nn_men_idxs = men_index.search(
                _men_embeds, n_mens_to_fetch)
        else:
            query_len = len(men_embeds) - (n_train_mentions
                                           if params['transductive'] else 0)
            nn_ent_idxs = np.zeros((query_len, params['recall_k']))
            nn_ent_dists = np.zeros((query_len, params['recall_k']),
                                    dtype='float64')
            nn_men_idxs = -1 * np.ones((query_len, query_len), dtype=int)
            nn_men_dists = -1 * np.ones(
                (query_len, query_len), dtype='float64')
            for entity_type in men_indexes:
                # In the transductive setting, exclude train mentions from the query set
                type_idxs = men_idxs_by_type[entity_type]
                if params['transductive']:
                    type_idxs = type_idxs[type_idxs >= n_train_mentions]
                men_embeds_by_type = men_embeds[type_idxs]
                nn_ent_dists_by_type, nn_ent_idxs_by_type = dict_indexes[
                    entity_type].search(men_embeds_by_type, params['recall_k'])
                nn_ent_idxs_by_type = np.array(
                    list(
                        map(lambda x: dict_idxs_by_type[entity_type][x],
                            nn_ent_idxs_by_type)))
                n_mens_to_fetch = len(
                    men_embeds_by_type) if within_doc else max_knn + 1
                nn_men_dists_by_type, nn_men_idxs_by_type = men_indexes[
                    entity_type].search(
                        men_embeds_by_type,
                        min(n_mens_to_fetch, len(men_embeds_by_type)))
                nn_men_idxs_by_type = np.array(
                    list(
                        map(lambda x: men_idxs_by_type[entity_type][x],
                            nn_men_idxs_by_type)))
                i = -1
                for idx in men_idxs_by_type[entity_type]:
                    if params['transductive']:
                        idx -= n_train_mentions
                    if idx < 0:
                        continue
                    i += 1
                    nn_ent_idxs[idx] = nn_ent_idxs_by_type[i]
                    nn_ent_dists[idx] = nn_ent_dists_by_type[i]
                    nn_men_idxs[idx][:len(nn_men_idxs_by_type[i])] = nn_men_idxs_by_type[i]
                    nn_men_dists[idx][:len(nn_men_dists_by_type[i])] = nn_men_dists_by_type[i]
        logger.info("Search finished")

        logger.info('Building graphs')
        # Find the most similar entity and k-nn mentions for each mention query
        for idx in range(len(nn_ent_idxs)):
            # Get nearest entity candidate
            dict_cand_idx = nn_ent_idxs[idx][0]
            dict_cand_score = nn_ent_dists[idx][0]
            # Compute recall metric
            gold_idxs = mention_data[idx][
                "label_idxs"][:mention_data[idx]["n_labels"]]
            recall_idx = np.argwhere(nn_ent_idxs[idx] == gold_idxs[0])
            if len(recall_idx) != 0:
                recall_idx = int(recall_idx)
                recall_idxs[recall_idx] += 1.
                for recall_k in recall_accuracy:
                    if recall_idx < recall_k:
                        recall_accuracy[recall_k] += 1.
            if not params['only_recall']:
                filter_mask_neg1 = nn_men_idxs[idx] != -1
                men_cand_idxs = nn_men_idxs[idx][filter_mask_neg1]
                men_cand_scores = nn_men_dists[idx][filter_mask_neg1]

                if within_doc:
                    men_cand_idxs, wd_mask = filter_by_context_doc_id(
                        men_cand_idxs,
                        test_context_doc_ids[idx],
                        test_context_doc_ids,
                        return_numpy=True)
                    men_cand_scores = men_cand_scores[wd_mask]

                # Filter candidates to remove mention query and keep only the top k candidates
                filter_mask = men_cand_idxs != idx
                men_cand_idxs, men_cand_scores = men_cand_idxs[
                    filter_mask][:max_knn], men_cand_scores[
                        filter_mask][:max_knn]

                if params['transductive']:
                    idx += n_train_mentions
                # Add edges to the graphs
                for k in joint_graphs:
                    joint_graph = joint_graphs[k]
                    # Add mention-entity edge
                    # Mention nodes are offset by the number of entities
                    joint_graph['rows'] = np.append(joint_graph['rows'],
                                                    n_entities + idx)
                    joint_graph['cols'] = np.append(joint_graph['cols'],
                                                    dict_cand_idx)
                    joint_graph['data'] = np.append(joint_graph['data'],
                                                    dict_cand_score)
                    if k > 0:
                        # Add mention-mention edges
                        joint_graph['rows'] = np.append(
                            joint_graph['rows'],
                            [n_entities + idx] * len(men_cand_idxs[:k]))
                        joint_graph['cols'] = np.append(
                            joint_graph['cols'],
                            n_entities + men_cand_idxs[:k])
                        joint_graph['data'] = np.append(
                            joint_graph['data'], men_cand_scores[:k])

        if params['transductive']:
            # Add positive infinity mention-entity edges from training queries to labeled entities
            for idx, train_men in enumerate(train_mention_data):
                dict_cand_idx = train_men["label_idxs"][0]
                for k in joint_graphs:
                    joint_graph = joint_graphs[k]
                    # Mention nodes are offset by the number of entities
                    joint_graph['rows'] = np.append(joint_graph['rows'],
                                                    n_entities + idx)
                    joint_graph['cols'] = np.append(joint_graph['cols'],
                                                    dict_cand_idx)
                    joint_graph['data'] = np.append(joint_graph['data'],
                                                    float('inf'))

        # Compute and print recall metric
        recall_idx_mode = np.argmax(recall_idxs)
        recall_idx_mode_prop = recall_idxs[recall_idx_mode] / np.sum(
            recall_idxs)
        logger.info(f"""
        Recall metrics (for {len(nn_ent_idxs)} queries):
        ---------------""")
        logger.info(
            f"highest recall idx = {recall_idx_mode} ({recall_idxs[recall_idx_mode]}/{np.sum(recall_idxs)} = {recall_idx_mode_prop})"
        )
        for recall_k in recall_accuracy:
            recall_accuracy[recall_k] /= len(nn_ent_idxs)
            logger.info(f"recall@{recall_k} = {recall_accuracy[recall_k]}")

        if params['only_recall']:
            exit()

        # Pickle the graphs
        print("Saving joint graphs...")
        with open(graph_path, 'wb') as write_handle:
            pickle.dump(joint_graphs,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

        if params['only_embed_and_build']:
            logger.info(f"Saved embedding data at: {embed_data_path}")
            logger.info(f"Saved graphs at: {graph_path}")
            exit()

    graph_mode = params.get('graph_mode', None)

    result_overview = {
        'n_entities': n_entities,
        'n_mentions': n_mentions - (n_train_mentions if params['transductive'] else 0)
    }
    results = {}
    if graph_mode not in ['directed', 'undirected']:
        results['directed'] = []
        results['undirected'] = []
    else:
        results[graph_mode] = []

    knn_fetch_time = time.time() - time_start
    graph_processing_time = time.time()
    n_graphs_processed = 0.

    for mode in results:
        print(f'\nEvaluation mode: {mode.upper()}')
        for k in joint_graphs:
            if k <= knn:
                print(f"\nGraph (k={k}):")
                # Partition graph based on cluster-linking constraints
                partitioned_graph, clusters = partition_graph(
                    joint_graphs[k],
                    n_entities,
                    mode == 'directed',
                    return_clusters=True)
                # Infer predictions from clusters
                result = analyzeClusters(
                    clusters, test_dictionary, mention_data, k,
                    n_train_mentions if params['transductive'] else 0)
                # Store result
                results[mode].append(result)
                n_graphs_processed += 1

    avg_graph_processing_time = (time.time() -
                                 graph_processing_time) / n_graphs_processed
    avg_per_graph_time = (knn_fetch_time + avg_graph_processing_time) / 60

    execution_time = (time.time() - time_start) / 60
    # Store results
    output_file_name = os.path.join(
        output_path,
        f"eval_results_{__import__('calendar').timegm(__import__('time').gmtime())}"
    )

    try:
        for recall_k in recall_accuracy:
            result_overview[f'recall@{recall_k}'] = recall_accuracy[recall_k]
    except NameError:
        # recall_accuracy is only defined when the graphs were built in this run
        logger.info(
            "Recall data not available since graphs were loaded from disk")

    for mode in results:
        mode_results = results[mode]
        result_overview[mode] = {}
        for r in mode_results:
            k = r['knn_mentions']
            result_overview[mode][f'accuracy@knn{k}'] = r['accuracy']
            logger.info(f"{mode} accuracy@knn{k} = {r['accuracy']}")
            output_file = f'{output_file_name}-{mode}-{k}.json'
            with open(output_file, 'w') as f:
                json.dump(r, f, indent=2)
                print(
                    f"\nPredictions ({mode}) @knn{k} saved at: {output_file}")
    with open(f'{output_file_name}.json', 'w') as f:
        json.dump(result_overview, f, indent=2)
        print(f"\nPredictions overview saved at: {output_file_name}.json")

    logger.info("\nThe avg. per graph evaluation time is {} minutes\n".format(
        avg_per_graph_time))
    logger.info(
        "\nThe total evaluation took {} minutes\n".format(execution_time))