import calendar
import json
import math
import os
import pickle
import random
import time
from collections import defaultdict

import numpy as np
import torch
import higra as hg
from IPython import embed  # used only by a leftover debug check below
from sklearn.preprocessing import normalize
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm, trange

# NOTE: The following names are assumed to be provided by the surrounding package
# (they are not defined in this file): utils, data_process, BiEncoderRanker,
# read_data, filter_by_context_doc_id, get_query_nn, partition_graph,
# analyzeClusters, eval_cluster_linking, load_data, get_optimizer, get_scheduler,
# get_hac_tree.


def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"], 'log-eval')

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path

    embed_data_path = params["embed_data_path"]
    if embed_data_path is None or not os.path.exists(embed_data_path):
        embed_data_path = output_path

    cands_path = os.path.join(output_path, 'cands.pickle')

    # Init model
    reranker = BiEncoderRanker(params)
    reranker.model.eval()
    tokenizer = reranker.tokenizer
    n_gpu = reranker.n_gpu

    knn = params["knn"]
    use_types = params["use_types"]
    within_doc = params["within_doc"]
    data_split = params["data_split"]  # Default = "test"

    # Load test data
    test_samples = None
    entity_dictionary_loaded = False
    test_dictionary_pkl_path = os.path.join(pickle_src_path, 'test_dictionary.pickle')
    test_tensor_data_pkl_path = os.path.join(pickle_src_path, 'test_tensor_data.pickle')
    test_mention_data_pkl_path = os.path.join(pickle_src_path, 'test_mention_data.pickle')
    if params['transductive']:
        train_tensor_data_pkl_path = os.path.join(pickle_src_path, 'train_tensor_data.pickle')
        train_mention_data_pkl_path = os.path.join(pickle_src_path, 'train_mention_data.pickle')
    if os.path.isfile(test_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(test_dictionary_pkl_path, 'rb') as read_handle:
            test_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if os.path.isfile(test_tensor_data_pkl_path) and os.path.isfile(test_mention_data_pkl_path):
        print("Loading stored processed test data...")
        with open(test_tensor_data_pkl_path, 'rb') as read_handle:
            test_tensor_data = pickle.load(read_handle)
        with open(test_mention_data_pkl_path, 'rb') as read_handle:
            mention_data = pickle.load(read_handle)
    else:
        test_samples = utils.read_dataset(data_split, params["data_path"])
        if not entity_dictionary_loaded:
            with open(os.path.join(params["data_path"], 'dictionary.pickle'), 'rb') as read_handle:
                test_dictionary = pickle.load(read_handle)

        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in test_samples[0].keys()
        if params["filter_unlabeled"]:
            # Filter samples without gold entities
            test_samples = list(
                filter(
                    lambda sample: (len(sample["labels"]) > 0) if mult_labels
                    else (sample["label"] is not None), test_samples))
        logger.info("Read %d test samples." % len(test_samples))

        mention_data, test_dictionary, test_tensor_data = data_process.process_mention_data(
            test_samples, test_dictionary, tokenizer,
            params["max_context_length"], params["max_cand_length"],
            multi_label_key="labels" if mult_labels else None,
            context_key=params["context_key"], silent=params["silent"],
            logger=logger, debug=params["debug"], knn=knn,
            dictionary_processed=entity_dictionary_loaded)
        print("Saving processed test data...")
        if not entity_dictionary_loaded:
            with open(test_dictionary_pkl_path, 'wb') as write_handle:
                pickle.dump(test_dictionary, write_handle, protocol=pickle.HIGHEST_PROTOCOL)
            entity_dictionary_loaded = True
        with open(test_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(test_tensor_data, write_handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_mention_data_pkl_path, 'wb') as write_handle:
            pickle.dump(mention_data, write_handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Store test dictionary token ids
    test_dict_vecs = torch.tensor(list(map(lambda x: x['ids'], test_dictionary)), dtype=torch.long)
    # Store test mention token ids
    test_men_vecs = test_tensor_data[:][0]

    n_entities = len(test_dict_vecs)
    n_mentions = len(test_tensor_data)

    if within_doc:
        if test_samples is None:
            test_samples, _ = read_data(data_split, params, logger)
        test_context_doc_ids = [s['context_doc_id'] for s in test_samples]

    if params["transductive"]:
        if os.path.isfile(train_tensor_data_pkl_path) and os.path.isfile(train_mention_data_pkl_path):
            print("Loading stored processed train data...")
            with open(train_tensor_data_pkl_path, 'rb') as read_handle:
                train_tensor_data = pickle.load(read_handle)
            with open(train_mention_data_pkl_path, 'rb') as read_handle:
                train_mention_data = pickle.load(read_handle)
        else:
            train_samples = utils.read_dataset('train', params["data_path"])

            # Check if dataset has multiple ground-truth labels
            mult_labels = "labels" in train_samples[0].keys()
            logger.info("Read %d train samples." % len(train_samples))

            train_mention_data, _, train_tensor_data = data_process.process_mention_data(
                train_samples, test_dictionary, tokenizer,
                params["max_context_length"], params["max_cand_length"],
                multi_label_key="labels" if mult_labels else None,
                context_key=params["context_key"], silent=params["silent"],
                logger=logger, debug=params["debug"], knn=knn,
                dictionary_processed=entity_dictionary_loaded)
            print("Saving processed train data...")
            with open(train_tensor_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_tensor_data, write_handle, protocol=pickle.HIGHEST_PROTOCOL)
            with open(train_mention_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_mention_data, write_handle, protocol=pickle.HIGHEST_PROTOCOL)

        # Store train mention token ids
        train_men_vecs = train_tensor_data[:][0]
        n_mentions += len(train_tensor_data)
        n_train_mentions = len(train_tensor_data)

    # if os.path.isfile(cands_path):
    #     print("Loading stored candidates...")
    #     with open(cands_path, 'rb') as read_handle:
    #         men_cands = pickle.load(read_handle)
    # else:

    # Check and load stored embedding data
    embed_data_path = os.path.join(embed_data_path, 'embed_data.t7')
    embed_data = None
    if os.path.isfile(embed_data_path):
        embed_data = torch.load(embed_data_path)

    if use_types:
        if embed_data is not None:
            logger.info('Loading stored embeddings and computing indexes')
            dict_embeds = embed_data['dict_embeds']
            if 'dict_idxs_by_type' in embed_data:
                dict_idxs_by_type = embed_data['dict_idxs_by_type']
            else:
                dict_idxs_by_type = data_process.get_idxs_by_type(test_dictionary)
            dict_indexes = data_process.get_index_from_embeds(
                dict_embeds, dict_idxs_by_type,
                force_exact_search=params['force_exact_search'],
                probe_mult_factor=params['probe_mult_factor'])
            men_embeds = embed_data['men_embeds']
            if 'men_idxs_by_type' in embed_data:
                men_idxs_by_type = embed_data['men_idxs_by_type']
            else:
                men_idxs_by_type = data_process.get_idxs_by_type(mention_data)
            men_indexes = data_process.get_index_from_embeds(
                men_embeds, men_idxs_by_type,
                force_exact_search=params['force_exact_search'],
                probe_mult_factor=params['probe_mult_factor'])
        else:
            logger.info("Dictionary: Embedding and building index")
            dict_embeds, dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
                reranker, test_dict_vecs, encoder_type="candidate", n_gpu=n_gpu,
                corpus=test_dictionary,
                force_exact_search=params['force_exact_search'],
                batch_size=params['embed_batch_size'],
                probe_mult_factor=params['probe_mult_factor'])
            logger.info("Queries: Embedding and building index")
            vecs = test_men_vecs
            men_data = mention_data
            if params['transductive']:
                vecs = torch.cat((train_men_vecs, vecs), dim=0)
                men_data = train_mention_data + mention_data
            men_embeds, men_indexes, men_idxs_by_type = data_process.embed_and_index(
                reranker, vecs, encoder_type="context", n_gpu=n_gpu,
                corpus=men_data,
                force_exact_search=params['force_exact_search'],
                batch_size=params['embed_batch_size'],
                probe_mult_factor=params['probe_mult_factor'])
    else:
        if embed_data is not None:
            logger.info('Loading stored embeddings and computing indexes')
            dict_embeds = embed_data['dict_embeds']
            dict_index = data_process.get_index_from_embeds(
                dict_embeds, force_exact_search=params['force_exact_search'],
                probe_mult_factor=params['probe_mult_factor'])
            men_embeds = embed_data['men_embeds']
            men_index = data_process.get_index_from_embeds(
                men_embeds, force_exact_search=params['force_exact_search'],
                probe_mult_factor=params['probe_mult_factor'])
        else:
            logger.info("Dictionary: Embedding and building index")
            dict_embeds, dict_index = data_process.embed_and_index(
                reranker, test_dict_vecs, 'candidate', n_gpu=n_gpu,
                force_exact_search=params['force_exact_search'],
                batch_size=params['embed_batch_size'],
                probe_mult_factor=params['probe_mult_factor'])
            logger.info("Queries: Embedding and building index")
            vecs = test_men_vecs
            if params['transductive']:
                vecs = torch.cat((train_men_vecs, vecs), dim=0)
            men_embeds, men_index = data_process.embed_and_index(
                reranker, vecs, 'context', n_gpu=n_gpu,
                force_exact_search=params['force_exact_search'],
                batch_size=params['embed_batch_size'],
                probe_mult_factor=params['probe_mult_factor'])

    # Save computed embedding data if not loaded from disk
    if embed_data is None:
        embed_data = {}
        embed_data['dict_embeds'] = dict_embeds
        embed_data['men_embeds'] = men_embeds
        if use_types:
            embed_data['dict_idxs_by_type'] = dict_idxs_by_type
            embed_data['men_idxs_by_type'] = men_idxs_by_type
        # NOTE: Cannot pickle faiss index because it is a SwigPyObject
        torch.save(embed_data, embed_data_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)

    logger.info("Starting KNN search...")
    # Fetch (k+1) NN mention candidates
    if not use_types:
        _men_embeds = men_embeds
        if params['transductive']:
            _men_embeds = _men_embeds[n_train_mentions:]
        n_mens_to_fetch = len(_men_embeds) if within_doc else knn + 1
        nn_men_dists, nn_men_idxs = men_index.search(_men_embeds, n_mens_to_fetch)
    else:
        query_len = len(men_embeds) - (n_train_mentions if params['transductive'] else 0)
        nn_men_idxs = -1 * np.ones((query_len, query_len), dtype=int)
        nn_men_dists = -1 * np.ones((query_len, query_len), dtype='float64')
        for entity_type in men_indexes:
            men_embeds_by_type = men_embeds[
                men_idxs_by_type[entity_type][men_idxs_by_type[entity_type] >= n_train_mentions]
            ] if params['transductive'] else men_embeds[men_idxs_by_type[entity_type]]
            n_mens_to_fetch = len(men_embeds_by_type) if within_doc else knn + 1
            nn_men_dists_by_type, nn_men_idxs_by_type = men_indexes[entity_type].search(
                men_embeds_by_type, min(n_mens_to_fetch, len(men_embeds_by_type)))
            nn_men_idxs_by_type = np.array(
                list(map(lambda x: men_idxs_by_type[entity_type][x], nn_men_idxs_by_type)))
            i = -1
            for idx in men_idxs_by_type[entity_type]:
                if params['transductive']:
                    idx -= n_train_mentions
                if idx < 0:
                    continue
                i += 1
                nn_men_idxs[idx][:len(nn_men_idxs_by_type[i])] = nn_men_idxs_by_type[i]
                nn_men_dists[idx][:len(nn_men_dists_by_type[i])] = nn_men_dists_by_type[i]
    logger.info("Search finished")

    logger.info(f'Calculating mention recall@{knn}')
    # Find the most similar entity and k-nn mentions for each mention query
    men_recall_knn = []
    men_cands = []
    cui_sums = defaultdict(int)
    for m in mention_data:
        cui_sums[m['label_cuis'][0]] += 1
    for idx in range(len(nn_men_idxs)):
        filter_mask_neg1 = nn_men_idxs[idx] != -1
        men_cand_idxs = nn_men_idxs[idx][filter_mask_neg1]
        men_cand_scores = nn_men_dists[idx][filter_mask_neg1]

        if within_doc:
            men_cand_idxs, wd_mask = filter_by_context_doc_id(
                men_cand_idxs, test_context_doc_ids[idx], test_context_doc_ids,
                return_numpy=True)
            men_cand_scores = men_cand_scores[wd_mask]

        # Filter candidates to remove mention query and keep only the top k candidates
        filter_mask = men_cand_idxs != idx
        men_cand_idxs, men_cand_scores = men_cand_idxs[filter_mask][:knn], men_cand_scores[filter_mask][:knn]
        assert len(men_cand_idxs) == knn

        men_cands.append(men_cand_idxs)

        # Calculate mention recall@k
        gold_cui = mention_data[idx]['label_cuis'][0]
        if cui_sums[gold_cui] > 1:
            recall_hit = [1. for midx in men_cand_idxs
                          if mention_data[midx]['label_cuis'][0] == gold_cui]
            recall_knn = sum(recall_hit) / min(cui_sums[gold_cui] - 1, knn)
            men_recall_knn.append(recall_knn)
    logger.info('Done')
    # assert len(men_recall_knn) == len(nn_men_idxs)

    # Pickle the candidates
    print(f"Saving top-{knn} candidates")
    with open(cands_path, 'wb') as write_handle:
        pickle.dump(men_cands, write_handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Output final results
    mean_recall_knn = np.mean(men_recall_knn)
    logger.info(f"Result: Final mean recall@{knn} = {mean_recall_knn}")
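
# Illustrative sketch (not part of the original pipeline): a minimal, self-contained
# replica of the mention recall@k computation above, on toy data. The names
# `_toy_mention_recall_at_k`, `toy_mention_cuis`, and `toy_nn_idxs` are hypothetical
# stand-ins for `mention_data` and `nn_men_idxs`; queries whose gold CUI has no other
# mention are skipped, and the denominator is capped at min(#co-clustered - 1, k).
def _toy_mention_recall_at_k(toy_mention_cuis, toy_nn_idxs, k):
    cui_sums = defaultdict(int)
    for cui in toy_mention_cuis:
        cui_sums[cui] += 1
    recalls = []
    for idx, cands in enumerate(toy_nn_idxs):
        cands = [c for c in cands if c != idx][:k]  # drop the self-match, keep top-k
        gold_cui = toy_mention_cuis[idx]
        if cui_sums[gold_cui] > 1:
            hits = sum(1. for c in cands if toy_mention_cuis[c] == gold_cui)
            recalls.append(hits / min(cui_sums[gold_cui] - 1, k))
    return np.mean(recalls)

# Mentions 0 and 2 share CUI 'C1' and retrieve each other first, so recall@1 is 1.0:
# _toy_mention_recall_at_k(['C1', 'C2', 'C1'], [[0, 2, 1], [1, 0, 2], [2, 0, 1]], k=1)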
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"], 'log')

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path

    embed_data_path = params["embed_data_path"]
    if embed_data_path is None or not os.path.exists(embed_data_path):
        embed_data_path = output_path

    # Init model
    reranker = BiEncoderRanker(params)
    reranker.model.eval()
    tokenizer = reranker.tokenizer
    n_gpu = reranker.n_gpu

    knn = params["knn"]  # Used as the max-knn value for the graph construction
    use_types = params["use_types"]
    # within_doc = params["within_doc"]
    data_split = params["data_split"]  # Default = "test"

    # Load test data
    entity_dictionary_loaded = False
    test_dictionary_pkl_path = os.path.join(pickle_src_path, 'test_dictionary.pickle')
    test_tensor_data_pkl_path = os.path.join(pickle_src_path, 'test_tensor_data.pickle')
    test_mention_data_pkl_path = os.path.join(pickle_src_path, 'test_mention_data.pickle')
    # if params['transductive']:
    #     train_tensor_data_pkl_path = os.path.join(pickle_src_path, 'train_tensor_data.pickle')
    #     train_mention_data_pkl_path = os.path.join(pickle_src_path, 'train_mention_data.pickle')
    if os.path.isfile(test_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(test_dictionary_pkl_path, 'rb') as read_handle:
            test_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if os.path.isfile(test_tensor_data_pkl_path) and os.path.isfile(test_mention_data_pkl_path):
        print("Loading stored processed test data...")
        with open(test_tensor_data_pkl_path, 'rb') as read_handle:
            test_tensor_data = pickle.load(read_handle)
        with open(test_mention_data_pkl_path, 'rb') as read_handle:
            mention_data = pickle.load(read_handle)
    else:
        test_samples = utils.read_dataset(data_split, params["data_path"])
        if not entity_dictionary_loaded:
            with open(os.path.join(params["data_path"], 'dictionary.pickle'), 'rb') as read_handle:
                test_dictionary = pickle.load(read_handle)

        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in test_samples[0].keys()
        # Filter samples without gold entities
        test_samples = list(
            filter(
                lambda sample: (len(sample["labels"]) > 0) if mult_labels
                else (sample["label"] is not None), test_samples))
        logger.info("Read %d test samples." % len(test_samples))

        mention_data, test_dictionary, test_tensor_data = data_process.process_mention_data(
            test_samples, test_dictionary, tokenizer,
            params["max_context_length"], params["max_cand_length"],
            multi_label_key="labels" if mult_labels else None,
            context_key=params["context_key"], silent=params["silent"],
            logger=logger, debug=params["debug"], knn=knn,
            dictionary_processed=entity_dictionary_loaded)
        print("Saving processed test data...")
        if not entity_dictionary_loaded:
            with open(test_dictionary_pkl_path, 'wb') as write_handle:
                pickle.dump(test_dictionary, write_handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(test_tensor_data, write_handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_mention_data_pkl_path, 'wb') as write_handle:
            pickle.dump(mention_data, write_handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Reduce the entity dictionary to only the ground truth of the mention queries,
    # and combine the entities and mentions into one structure for joint embedding
    # and indexing
    new_ents = {}
    new_ents_arr = []
    men_labels = []
    for men in mention_data:
        ent = men['label_idxs'][0]
        if ent not in new_ents:
            new_ents[ent] = len(new_ents_arr)
            new_ents_arr.append(ent)
        men_labels.append(new_ents[ent])
    ent_labels = [i for i in range(len(new_ents_arr))]
    new_ent_vecs = torch.tensor(list(map(lambda x: test_dictionary[x]['ids'], new_ents_arr)))
    new_ent_types = list(map(lambda x: {"type": test_dictionary[x]['type']}, new_ents_arr))
    test_men_vecs = test_tensor_data[:][0]

    n_mentions = len(test_tensor_data)
    n_entities = len(new_ent_vecs)
    n_embeds = n_mentions + n_entities
    leaf_labels = np.array(ent_labels + men_labels, dtype=int)
    all_vecs = torch.cat((new_ent_vecs, test_men_vecs))
    all_types = new_ent_types + mention_data  # Array of dicts containing key "type" for selected ents and all mentions

    # Values of k to run the evaluation against
    knn_vals = [25 * 2**i for i in range(int(math.log(knn / 25, 2)) + 1)
                ] if params["exact_knn"] is None else [params["exact_knn"]]
    # Store the maximum evaluation k
    max_knn = knn_vals[-1]

    time_start = time.time()

    # Check if graphs are already built
    graph_path = os.path.join(output_path, 'graphs.pickle')
    if os.path.isfile(graph_path):
        print("Loading stored joint graphs...")
        with open(graph_path, 'rb') as read_handle:
            joint_graphs = pickle.load(read_handle)
        # Keeps the timing report below well-defined when graphs are loaded from disk
        knn_fetch_time = time.time() - time_start
    else:
        # Initialize graphs to store mention-mention and mention-entity similarity score edges;
        # Keyed on k, the number of nearest mentions retrieved
        joint_graphs = {}
        for k in knn_vals:
            joint_graphs[k] = {
                'rows': np.array([]),
                'cols': np.array([]),
                'data': np.array([]),
                'shape': (n_embeds, n_embeds)
            }

        # Check and load stored embedding data
        embed_data_path = os.path.join(embed_data_path, 'embed_data.t7')
        embed_data = None
        if os.path.isfile(embed_data_path):
            embed_data = torch.load(embed_data_path)

        if use_types:
            if embed_data is not None:
                logger.info('Loading stored embeddings')
                embeds = embed_data['embeds']
                if 'idxs_by_type' in embed_data:
                    idxs_by_type = embed_data['idxs_by_type']
                else:
                    idxs_by_type = data_process.get_idxs_by_type(all_types)
            else:
                logger.info("Embedding data")
                dict_embeds = data_process.embed_and_index(
                    reranker, all_vecs[:n_entities], encoder_type='candidate',
                    only_embed=True, n_gpu=n_gpu,
                    batch_size=params['embed_batch_size'])
                men_embeds = data_process.embed_and_index(
                    reranker, all_vecs[n_entities:], encoder_type='context',
                    only_embed=True, n_gpu=n_gpu,
                    batch_size=params['embed_batch_size'])
                embeds = np.concatenate((dict_embeds, men_embeds), axis=0)
                idxs_by_type = data_process.get_idxs_by_type(all_types)
                search_indexes = data_process.get_index_from_embeds(
                    embeds, corpus_idxs=idxs_by_type, force_exact_search=True)
        else:
            if embed_data is not None:
                logger.info('Loading stored embeddings')
                embeds = embed_data['embeds']
            else:
                logger.info("Embedding data")
                dict_embeds = data_process.embed_and_index(
                    reranker, all_vecs[:n_entities], encoder_type='candidate',
                    only_embed=True, n_gpu=n_gpu,
                    batch_size=params['embed_batch_size'])
                men_embeds = data_process.embed_and_index(
                    reranker, all_vecs[n_entities:], encoder_type='context',
                    only_embed=True, n_gpu=n_gpu,
                    batch_size=params['embed_batch_size'])
                embeds = np.concatenate((dict_embeds, men_embeds), axis=0)
                search_index = data_process.get_index_from_embeds(
                    embeds, force_exact_search=True)

        # Save computed embedding data if not loaded from disk
        if embed_data is None:
            embed_data = {}
            embed_data['embeds'] = embeds
            if use_types:
                embed_data['idxs_by_type'] = idxs_by_type
            # NOTE: Cannot pickle faiss index because it is a SwigPyObject
            torch.save(embed_data, embed_data_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)

        # Build faiss search index
        if params["normalize_embeds"]:
            embeds = normalize(embeds, axis=0)
        logger.info("Building KNN index...")
        if use_types:
            search_indexes = data_process.get_index_from_embeds(
                embeds, corpus_idxs=idxs_by_type, force_exact_search=True)
        else:
            search_index = data_process.get_index_from_embeds(
                embeds, force_exact_search=True)

        logger.info("Starting KNN search...")
        if not use_types:
            faiss_dists, faiss_idxs = search_index.search(embeds, max_knn + 1)
        else:
            query_len = n_embeds
            faiss_idxs = np.zeros((query_len, max_knn + 1), dtype=int)
            faiss_dists = np.zeros((query_len, max_knn + 1), dtype=float)
            for entity_type in search_indexes:
                embeds_by_type = embeds[idxs_by_type[entity_type]]
                nn_dists_by_type, nn_idxs_by_type = search_indexes[entity_type].search(
                    embeds_by_type, max_knn + 1)
                for i, idx in enumerate(idxs_by_type[entity_type]):
                    faiss_idxs[idx] = nn_idxs_by_type[i]
                    faiss_dists[idx] = nn_dists_by_type[i]
        logger.info("Search finished")

        logger.info('Building graphs')
        # Find the most similar nodes for each mention and node in the set (minus self)
        for idx in trange(n_embeds):
            # Compute adjacent node edge weight
            if idx != 0:
                adj_idx = idx - 1
                adj_data = embeds[adj_idx] @ embeds[idx]
            nn_idxs = faiss_idxs[idx]
            nn_scores = faiss_dists[idx]
            # Filter candidates to remove the self-query and keep only the top k candidates
            filter_mask = nn_idxs != idx
            nn_idxs, nn_scores = nn_idxs[filter_mask][:max_knn], nn_scores[filter_mask][:max_knn]
            # Add edges to the graphs
            for k in joint_graphs:
                joint_graph = joint_graphs[k]
                # Add an edge to the adjacent node to force the graph to be connected
                if idx != 0:
                    joint_graph['rows'] = np.append(joint_graph['rows'], adj_idx)
                    joint_graph['cols'] = np.append(joint_graph['cols'], idx)
                    joint_graph['data'] = np.append(joint_graph['data'], adj_data)
                # Add mention-mention edges
                joint_graph['rows'] = np.append(joint_graph['rows'], [idx] * k)
                joint_graph['cols'] = np.append(joint_graph['cols'], nn_idxs[:k])
                joint_graph['data'] = np.append(joint_graph['data'], nn_scores[:k])
        knn_fetch_time = time.time() - time_start

        # Pickle the graphs
        print("Saving joint graphs...")
        with open(graph_path, 'wb') as write_handle:
            pickle.dump(joint_graphs, write_handle, protocol=pickle.HIGHEST_PROTOCOL)

        if params['only_embed_and_build']:
            logger.info(f"Saved embedding data at: {embed_data_path}")
            logger.info(f"Saved graphs at: {graph_path}")
            exit()

    results = {
        'n_leaves': n_embeds,
        'n_entities': n_entities,
        'n_mentions': n_mentions
    }

    graph_processing_start_time = time.time()
    n_graphs_processed = 0.
    # Different HAC linkage functions to run the analyses over
    linkage_fns = ["single", "complete", "average"] if params["linkage"] is None \
        else [params["linkage"]]
    for fn in linkage_fns:
        logger.info(f"Linkage function: {fn}")
        purities = []
        fn_result = {}
        for k in joint_graphs:
            graph = hg.UndirectedGraph(n_embeds)
            graph.add_edges(joint_graphs[k]['rows'], joint_graphs[k]['cols'])
            weights = -joint_graphs[k]['data']  # Higra expects edge weights as distances, not similarities
            tree = get_hac_tree(graph, weights, linkage=fn)
            purity = hg.dendrogram_purity(tree, leaf_labels)
            fn_result[f"purity@{k}nn"] = purity
            logger.info(f"purity@{k}nn = {purity}")
            purities.append(purity)
            n_graphs_processed += 1
        fn_result["average"] = round(np.mean(purities), 4)
        logger.info(f"average = {fn_result['average']}")
        results[fn] = fn_result

    avg_graph_processing_time = (time.time() - graph_processing_start_time) / n_graphs_processed
    avg_per_graph_time = (knn_fetch_time + avg_graph_processing_time) / 60
    execution_time = (time.time() - time_start) / 60

    # Store results
    output_file_name = os.path.join(output_path, f"results_{calendar.timegm(time.gmtime())}")
    logger.info(f"Results: \n {results}")
    with open(f'{output_file_name}.json', 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved at: {output_file_name}.json")

    logger.info("\nThe avg. per graph evaluation time is {} minutes\n".format(avg_per_graph_time))
    logger.info("\nThe total evaluation took {} minutes\n".format(execution_time))
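
# Minimal sketch of the graph -> HAC tree -> dendrogram-purity flow above, on a toy
# 4-node graph with ground-truth clusters {0, 1} and {2, 3}. `get_hac_tree` is not
# defined in this file; for the "single" linkage case it is assumed here to be
# equivalent to higra's built-in single-linkage partition tree.
def _toy_dendrogram_purity():
    g = hg.UndirectedGraph(4)
    g.add_edges(np.array([0, 1, 2]), np.array([1, 2, 3]))
    # Edge weights are distances (hence the negation of similarities above);
    # the weak 1-2 edge is merged last, so {0, 1} and {2, 3} form first.
    weights = np.array([0.1, 0.9, 0.1])
    tree, _ = hg.binary_partition_tree_single_linkage(g, weights)
    leaf_labels = np.array([0, 0, 1, 1], dtype=int)
    return hg.dendrogram_purity(tree, leaf_labels)  # -> 1.0 for this toy input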
def save_topk_biencoder_cands(bi_reranker, use_types, logger, n_gpu, params,
                              bi_tokenizer, max_context_length, max_cand_length,
                              pickle_src_path, topk=64):
    entity_dictionary = load_data('train', bi_tokenizer, max_context_length,
                                  max_cand_length, 1, pickle_src_path, params,
                                  logger, return_dict_only=True)
    entity_dict_vecs = torch.tensor(list(map(lambda x: x['ids'], entity_dictionary)), dtype=torch.long)

    logger.info('Biencoder: Embedding and indexing entity dictionary')
    if use_types:
        _, dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
            bi_reranker, entity_dict_vecs, encoder_type="candidate", n_gpu=n_gpu,
            corpus=entity_dictionary, force_exact_search=True,
            batch_size=params['embed_batch_size'])
    else:
        _, dict_index = data_process.embed_and_index(
            bi_reranker, entity_dict_vecs, encoder_type="candidate", n_gpu=n_gpu,
            force_exact_search=True, batch_size=params['embed_batch_size'])
    logger.info('Biencoder: Embedding and indexing finished')

    for mode in ["train", "valid", "test"]:
        logger.info(f"Biencoder: Fetching top-{topk} biencoder candidates for {mode} set")
        _, tensor_data, processed_data = load_data(mode, bi_tokenizer,
                                                   max_context_length,
                                                   max_cand_length, 1,
                                                   pickle_src_path, params, logger)
        men_vecs = tensor_data[:][0]

        logger.info('Biencoder: Embedding mention data')
        if use_types:
            men_embeddings, _, men_idxs_by_type = data_process.embed_and_index(
                bi_reranker, men_vecs, encoder_type="context", n_gpu=n_gpu,
                corpus=processed_data, force_exact_search=True,
                batch_size=params['embed_batch_size'])
        else:
            men_embeddings = data_process.embed_and_index(
                bi_reranker, men_vecs, encoder_type="context", n_gpu=n_gpu,
                force_exact_search=True, batch_size=params['embed_batch_size'],
                only_embed=True)
        logger.info('Biencoder: Embedding finished')

        logger.info("Biencoder: Finding nearest entities for each mention...")
        if not use_types:
            _, bi_dict_nns = dict_index.search(men_embeddings, topk)
        else:
            bi_dict_nns = np.zeros((len(men_embeddings), topk), dtype=int)
            for entity_type in men_idxs_by_type:
                men_embeds_by_type = men_embeddings[men_idxs_by_type[entity_type]]
                _, dict_nns_by_type = dict_indexes[entity_type].search(men_embeds_by_type, topk)
                dict_nns_idxs = np.array(
                    list(map(lambda x: dict_idxs_by_type[entity_type][x], dict_nns_by_type)))
                for i, idx in enumerate(men_idxs_by_type[entity_type]):
                    bi_dict_nns[idx] = dict_nns_idxs[i]
        logger.info("Biencoder: Search finished")

        labels = [-1] * len(bi_dict_nns)
        for men_idx in range(len(bi_dict_nns)):
            gold_idx = processed_data[men_idx]["label_idxs"][0]
            for i in range(len(bi_dict_nns[men_idx])):
                if bi_dict_nns[men_idx][i] == gold_idx:
                    labels[men_idx] = i
                    break

        logger.info(f"Biencoder: Saving top-{topk} biencoder candidates for {mode} set")
        save_data_path = os.path.join(params['output_path'], f'candidates_{mode}_top{topk}.t7')
        torch.save({
            "mode": mode,
            "candidates": bi_dict_nns,
            "labels": labels
        }, save_data_path)
        logger.info("Biencoder: Saved")
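
# Usage sketch (hypothetical consumer): each .t7 payload written above holds, per
# mention, the top-k entity indices and the gold entity's rank among them (-1 if
# the gold entity was not retrieved), e.g.:
# cands = torch.load(os.path.join(params['output_path'], 'candidates_train_top64.t7'))
# assert cands['mode'] == 'train' and len(cands['candidates']) == len(cands['labels'])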
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"])

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path

    # Init model
    reranker = BiEncoderRanker(params)
    reranker.model.eval()
    tokenizer = reranker.tokenizer
    model = reranker.model
    device = reranker.device
    n_gpu = reranker.n_gpu

    knn = params["knn"]
    directed_graph = params["directed_graph"]
    use_types = params["use_types"]
    data_split = params["data_split"]  # Parameter default is "test"

    # Load test data
    entity_dictionary_loaded = False
    test_dictionary_pkl_path = os.path.join(pickle_src_path, 'test_dictionary.pickle')
    test_tensor_data_pkl_path = os.path.join(pickle_src_path, 'test_tensor_data.pickle')
    test_mention_data_pkl_path = os.path.join(pickle_src_path, 'test_mention_data.pickle')
    if os.path.isfile(test_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(test_dictionary_pkl_path, 'rb') as read_handle:
            test_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if os.path.isfile(test_tensor_data_pkl_path) and os.path.isfile(test_mention_data_pkl_path):
        print("Loading stored processed test data...")
        with open(test_tensor_data_pkl_path, 'rb') as read_handle:
            test_tensor_data = pickle.load(read_handle)
        with open(test_mention_data_pkl_path, 'rb') as read_handle:
            mention_data = pickle.load(read_handle)
    else:
        test_samples = utils.read_dataset(data_split, params["data_path"])
        if not entity_dictionary_loaded:
            with open(os.path.join(params["data_path"], 'dictionary.pickle'), 'rb') as read_handle:
                test_dictionary = pickle.load(read_handle)

        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in test_samples[0].keys()
        if params["filter_unlabeled"]:
            # Filter samples without gold entities
            test_samples = list(
                filter(
                    lambda sample: (len(sample["labels"]) > 0) if mult_labels
                    else (sample["label"] is not None), test_samples))
        logger.info("Read %d test samples." % len(test_samples))

        mention_data, test_dictionary, test_tensor_data = data_process.process_mention_data(
            test_samples, test_dictionary, tokenizer,
            params["max_context_length"], params["max_cand_length"],
            multi_label_key="labels" if mult_labels else None,
            context_key=params["context_key"], silent=params["silent"],
            logger=logger, debug=params["debug"], knn=knn,
            dictionary_processed=entity_dictionary_loaded)
        print("Saving processed test data...")
        if not entity_dictionary_loaded:
            with open(test_dictionary_pkl_path, 'wb') as write_handle:
                pickle.dump(test_dictionary, write_handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(test_tensor_data, write_handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_mention_data_pkl_path, 'wb') as write_handle:
            pickle.dump(mention_data, write_handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Store test dictionary token ids
    test_dict_vecs = torch.tensor(list(map(lambda x: x['ids'], test_dictionary)), dtype=torch.long)
    # Store test mention token ids
    test_men_vecs = test_tensor_data[:][0]

    n_entities = len(test_dict_vecs)
    n_mentions = len(test_tensor_data)

    # Values of k to run the evaluation against
    knn_vals = [0] + [2**i for i in range(int(math.log(knn, 2)) + 1)]
    # Store the maximum evaluation k
    max_knn = knn_vals[-1]

    # Check if graphs are already built
    graph_path = os.path.join(output_path, 'graphs.pickle')
    if not params['only_recall'] and os.path.isfile(graph_path):
        print("Loading stored joint graphs...")
        with open(graph_path, 'rb') as read_handle:
            joint_graphs = pickle.load(read_handle)
    else:
        # Initialize graphs to store mention-mention and mention-entity similarity score edges;
        # Keyed on k, the number of nearest mentions retrieved
        joint_graphs = {}
        for k in knn_vals:
            joint_graphs[k] = {
                'rows': np.array([]),
                'cols': np.array([]),
                'data': np.array([]),
                'shape': (n_entities + n_mentions, n_entities + n_mentions)
            }

        if use_types:
            print("Dictionary: Embedding and building index")
            dict_embeds, dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
                reranker, test_dict_vecs, encoder_type="candidate", n_gpu=n_gpu,
                corpus=test_dictionary, force_exact_search=True)

            # Leftover debug check: verify freshly computed embeddings against a
            # stored set of original encodings, then exit
            og_embeds = torch.load('models/trained/zeshel_og/eval/data_og/cand_encodes.t7')
            world_to_type = {12: 'forgotten_realms', 13: 'lego', 14: 'star_trek', 15: 'yugioh'}
            for world in og_embeds:
                for i, oge in enumerate(tqdm(og_embeds[world])):
                    dict_embed_idx = dict_idxs_by_type[world_to_type[world]][i]
                    try:
                        assert torch.equal(oge, torch.tensor(dict_embeds[dict_embed_idx]))
                    except AssertionError:
                        embed()  # drop into an IPython shell for inspection
                        exit()
            print('PASS')
            exit()

            print("Queries: Embedding and building index")
            men_embeds, men_indexes, men_idxs_by_type = data_process.embed_and_index(
                reranker, test_men_vecs, encoder_type="context", n_gpu=n_gpu,
                corpus=mention_data, force_exact_search=True)
        else:
            print("Dictionary: Embedding and building index")
            dict_embeds, dict_index = data_process.embed_and_index(
                reranker, test_dict_vecs, 'candidate', n_gpu=n_gpu)
            print("Queries: Embedding and building index")
            men_embeds, men_index = data_process.embed_and_index(
                reranker, test_men_vecs, 'context', n_gpu=n_gpu)

        recall_accuracy = {2**i: 0 for i in range(int(math.log(params['recall_k'], 2)) + 1)}
        recall_idxs = [0.] * params['recall_k']

        # Find the most similar entity and k-nn mentions for each mention query
        for men_query_idx, men_embed in enumerate(
                tqdm(men_embeds, total=len(men_embeds), desc="Fetching k-NN")):
            men_embed = np.expand_dims(men_embed, axis=0)

            dict_type_idx_mapping, men_type_idx_mapping = None, None
            if use_types:
                entity_type = mention_data[men_query_idx]['type']
                dict_index = dict_indexes[entity_type]
                men_index = men_indexes[entity_type]
                dict_type_idx_mapping = dict_idxs_by_type[entity_type]
                men_type_idx_mapping = men_idxs_by_type[entity_type]

            # Fetch nearest entity candidate
            gold_idxs = mention_data[men_query_idx]["label_idxs"][:mention_data[men_query_idx]["n_labels"]]
            dict_cand_idx, dict_cand_score, recall_idx = get_query_nn(
                1, dict_embeds, dict_index, men_embed, searchK=params['recall_k'],
                gold_idxs=gold_idxs, type_idx_mapping=dict_type_idx_mapping)

            # Compute recall metric
            if recall_idx > -1:
                recall_idxs[recall_idx] += 1.
                for recall_k in recall_accuracy:
                    if recall_idx < recall_k:
                        recall_accuracy[recall_k] += 1.

            if not params['only_recall']:
                # Fetch (k+1) NN mention candidates
                men_cand_idxs, men_cand_scores = get_query_nn(
                    max_knn + 1, men_embeds, men_index, men_embed,
                    type_idx_mapping=men_type_idx_mapping)
                # Filter candidates to remove mention query and keep only the top k candidates
                filter_mask = men_cand_idxs != men_query_idx
                men_cand_idxs, men_cand_scores = men_cand_idxs[filter_mask][:max_knn], men_cand_scores[filter_mask][:max_knn]

                # Add edges to the graphs
                for k in joint_graphs:
                    joint_graph = joint_graphs[k]
                    # Add mention-entity edge
                    joint_graph['rows'] = np.append(
                        joint_graph['rows'],
                        [n_entities + men_query_idx])  # Mentions added at an offset of maximum entities
                    joint_graph['cols'] = np.append(joint_graph['cols'], dict_cand_idx)
                    joint_graph['data'] = np.append(joint_graph['data'], dict_cand_score)
                    if k > 0:
                        # Add mention-mention edges
                        joint_graph['rows'] = np.append(
                            joint_graph['rows'],
                            [n_entities + men_query_idx] * len(men_cand_idxs[:k]))
                        joint_graph['cols'] = np.append(
                            joint_graph['cols'], n_entities + men_cand_idxs[:k])
                        joint_graph['data'] = np.append(
                            joint_graph['data'], men_cand_scores[:k])

        # Compute and print recall metric
        recall_idx_mode = np.argmax(recall_idxs)
        recall_idx_mode_prop = recall_idxs[recall_idx_mode] / np.sum(recall_idxs)
        logger.info(f"""
Recall metrics (for {len(men_embeds)} queries):
---------------""")
        logger.info(
            f"highest recall idx = {recall_idx_mode} "
            f"({recall_idxs[recall_idx_mode]}/{np.sum(recall_idxs)} = {recall_idx_mode_prop})")
        for recall_k in recall_accuracy:
            recall_accuracy[recall_k] /= len(men_embeds)
            logger.info(f"recall@{recall_k} = {recall_accuracy[recall_k]}")

        if params['only_recall']:
            exit()

        # Pickle the graphs
        print("Saving joint graphs...")
        with open(graph_path, 'wb') as write_handle:
            pickle.dump(joint_graphs, write_handle, protocol=pickle.HIGHEST_PROTOCOL)

    results = []
    for k in joint_graphs:
        print(f"\nGraph (k={k}):")
        # Partition graph based on cluster-linking constraints
        partitioned_graph, clusters = partition_graph(
            joint_graphs[k], n_entities, directed_graph, return_clusters=True)
        # Infer predictions from clusters
        result = analyzeClusters(clusters, test_dictionary, mention_data, k)
        # Store result
        results.append(result)

    # Store results
    output_file_name = os.path.join(output_path, f"eval_results_{calendar.timegm(time.gmtime())}")
    result_overview = {
        'n_entities': results[0]['n_entities'],
        'n_mentions': results[0]['n_mentions'],
        'directed': directed_graph
    }
    try:
        for recall_k in recall_accuracy:
            result_overview[f'recall@{recall_k}'] = recall_accuracy[recall_k]
    except NameError:
        logger.info("Recall data not available since graphs were loaded from disk")
    for r in results:
        k = r['knn_mentions']
        result_overview[f'accuracy@knn{k}'] = r['accuracy']
        logger.info(f"accuracy@knn{k} = {r['accuracy']}")
        output_file = f'{output_file_name}-{k}.json'
        with open(output_file, 'w') as f:
            json.dump(r, f, indent=2)
        print(f"\nPredictions @knn{k} saved at: {output_file}")
    with open(f'{output_file_name}.json', 'w') as f:
        json.dump(result_overview, f, indent=2)
    print(f"\nPredictions overview saved at: {output_file_name}.json")
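
# Sketch under assumptions: `partition_graph` is defined elsewhere; the joint-graph
# dicts built above are plain COO triplets, so they convert directly to a scipy
# sparse matrix for inspection or clustering experiments, e.g.:
# from scipy.sparse import coo_matrix
# g = joint_graphs[k]
# adj = coo_matrix((g['data'], (g['rows'].astype(int), g['cols'].astype(int))),
#                  shape=g['shape'])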
def main(params):
    model_output_path = params["output_path"]
    if not os.path.exists(model_output_path):
        os.makedirs(model_output_path)
    logger = utils.get_logger(params["output_path"])

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = model_output_path

    knn = params["knn"]
    use_types = params["use_types"]

    # Init model
    reranker = BiEncoderRanker(params)
    tokenizer = reranker.tokenizer
    model = reranker.model
    device = reranker.device
    n_gpu = reranker.n_gpu

    if params["gradient_accumulation_steps"] < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                params["gradient_accumulation_steps"]))

    # To get an effective batch size of `x` while accumulating gradients across
    # `y` batches, use a per-step batch size of `z = x / y`
    params["train_batch_size"] = (params["train_batch_size"] // params["gradient_accumulation_steps"])
    train_batch_size = params["train_batch_size"]
    eval_batch_size = params["eval_batch_size"]
    grad_acc_steps = params["gradient_accumulation_steps"]

    # Fix the random seeds
    seed = params["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if reranker.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    entity_dictionary_loaded = False
    entity_dictionary_pkl_path = os.path.join(pickle_src_path, 'entity_dictionary.pickle')
    if os.path.isfile(entity_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(entity_dictionary_pkl_path, 'rb') as read_handle:
            entity_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True

    if not params["only_evaluate"]:
        # Load train data
        train_tensor_data_pkl_path = os.path.join(pickle_src_path, 'train_tensor_data.pickle')
        train_processed_data_pkl_path = os.path.join(pickle_src_path, 'train_processed_data.pickle')
        if os.path.isfile(train_tensor_data_pkl_path) and os.path.isfile(train_processed_data_pkl_path):
            print("Loading stored processed train data...")
            with open(train_tensor_data_pkl_path, 'rb') as read_handle:
                train_tensor_data = pickle.load(read_handle)
            with open(train_processed_data_pkl_path, 'rb') as read_handle:
                train_processed_data = pickle.load(read_handle)
        else:
            train_samples = utils.read_dataset("train", params["data_path"])
            if not entity_dictionary_loaded:
                with open(os.path.join(params["data_path"], 'dictionary.pickle'), 'rb') as read_handle:
                    entity_dictionary = pickle.load(read_handle)

            # Check if dataset has multiple ground-truth labels
            mult_labels = "labels" in train_samples[0].keys()
            if params["filter_unlabeled"]:
                # Filter samples without gold entities
                train_samples = list(
                    filter(
                        lambda sample: (len(sample["labels"]) > 0) if mult_labels
                        else (sample["label"] is not None), train_samples))
            logger.info("Read %d train samples." % len(train_samples))

            # For discovery experiment: Drop entities used in training that were dropped randomly from dev/test set
            if params["drop_entities"]:
                assert entity_dictionary_loaded
                drop_set_pkl_path = os.path.join(pickle_src_path, 'drop_set_mention_data.pickle')
                with open(drop_set_pkl_path, 'rb') as read_handle:
                    drop_set_data = pickle.load(read_handle)
                drop_set_mention_gold_cui_idxs = list(map(lambda x: x['label_idxs'][0], drop_set_data))
                ents_in_data = np.unique(drop_set_mention_gold_cui_idxs)
                ent_drop_prop = 0.1
                logger.info(f"Dropping {ent_drop_prop*100}% of {len(ents_in_data)} entities found in drop set")
                # Get entity indices to drop
                n_ents_dropped = int(ent_drop_prop * len(ents_in_data))
                rng = np.random.default_rng(seed=17)
                dropped_ent_idxs = rng.choice(ents_in_data, size=n_ents_dropped, replace=False)
                # Drop entities from dictionary (subsequent processing will automatically drop corresponding mentions)
                keep_mask = np.ones(len(entity_dictionary), dtype='bool')
                keep_mask[dropped_ent_idxs] = False
                entity_dictionary = np.array(entity_dictionary)[keep_mask]

            train_processed_data, entity_dictionary, train_tensor_data = data_process.process_mention_data(
                train_samples, entity_dictionary, tokenizer,
                params["max_context_length"], params["max_cand_length"],
                context_key=params["context_key"],
                multi_label_key="labels" if mult_labels else None,
                silent=params["silent"], logger=logger, debug=params["debug"],
                knn=knn, dictionary_processed=entity_dictionary_loaded)
            print("Saving processed train data...")
            if not entity_dictionary_loaded:
                with open(entity_dictionary_pkl_path, 'wb') as write_handle:
                    pickle.dump(entity_dictionary, write_handle, protocol=pickle.HIGHEST_PROTOCOL)
            with open(train_tensor_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_tensor_data, write_handle, protocol=pickle.HIGHEST_PROTOCOL)
            with open(train_processed_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_processed_data, write_handle, protocol=pickle.HIGHEST_PROTOCOL)

        # Store the query mention vectors
        train_men_vecs = train_tensor_data[:][0]

        if params["shuffle"]:
            train_sampler = RandomSampler(train_tensor_data)
        else:
            train_sampler = SequentialSampler(train_tensor_data)

        train_dataloader = DataLoader(train_tensor_data, sampler=train_sampler,
                                      batch_size=train_batch_size)

    # Store the entity dictionary vectors
    entity_dict_vecs = torch.tensor(list(map(lambda x: x['ids'], entity_dictionary)), dtype=torch.long)

    # Load eval data
    valid_tensor_data_pkl_path = os.path.join(pickle_src_path, 'valid_tensor_data.pickle')
    valid_processed_data_pkl_path = os.path.join(pickle_src_path, 'valid_processed_data.pickle')
    if os.path.isfile(valid_tensor_data_pkl_path) and os.path.isfile(valid_processed_data_pkl_path):
        print("Loading stored processed valid data...")
        with open(valid_tensor_data_pkl_path, 'rb') as read_handle:
            valid_tensor_data = pickle.load(read_handle)
        with open(valid_processed_data_pkl_path, 'rb') as read_handle:
            valid_processed_data = pickle.load(read_handle)
    else:
        valid_samples = utils.read_dataset("valid", params["data_path"])

        # Check if dataset has multiple ground-truth labels
        mult_labels = "labels" in valid_samples[0].keys()
        # Filter samples without gold entities
        valid_samples = list(
            filter(
                lambda sample: (len(sample["labels"]) > 0) if mult_labels
                else (sample["label"] is not None), valid_samples))
        logger.info("Read %d valid samples." % len(valid_samples))

        valid_processed_data, _, valid_tensor_data = data_process.process_mention_data(
            valid_samples, entity_dictionary, tokenizer,
            params["max_context_length"], params["max_cand_length"],
            context_key=params["context_key"],
            multi_label_key="labels" if mult_labels else None,
            silent=params["silent"], logger=logger, debug=params["debug"],
            knn=knn, dictionary_processed=True)
        print("Saving processed valid data...")
        with open(valid_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(valid_tensor_data, write_handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(valid_processed_data_pkl_path, 'wb') as write_handle:
            pickle.dump(valid_processed_data, write_handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Store the query mention vectors
    valid_men_vecs = valid_tensor_data[:][0]

    # valid_sampler = SequentialSampler(valid_tensor_data)
    # valid_dataloader = DataLoader(
    #     valid_tensor_data, sampler=valid_sampler, batch_size=eval_batch_size
    # )

    if params["only_evaluate"]:
        evaluate(reranker, entity_dict_vecs, valid_men_vecs, device=device,
                 logger=logger, knn=knn, n_gpu=n_gpu,
                 entity_data=entity_dictionary, query_data=valid_processed_data,
                 silent=params["silent"],
                 use_types=use_types or params["use_types_for_eval"],
                 embed_batch_size=params['embed_batch_size'],
                 force_exact_search=use_types or params["use_types_for_eval"] or params["force_exact_search"],
                 probe_mult_factor=params['probe_mult_factor'])
        exit()

    time_start = time.time()
    utils.write_to_file(os.path.join(model_output_path, "training_params.txt"), str(params))
    logger.info("Starting training")
    logger.info("device: {} n_gpu: {}, data_parallel: {}".format(
        device, n_gpu, params["data_parallel"]))

    # Set model to training mode
    optimizer = get_optimizer(model, params)
    scheduler = get_scheduler(params, optimizer, len(train_tensor_data), logger)
    best_epoch_idx = -1
    best_score = -1
    num_train_epochs = params["num_train_epochs"]

    init_base_model_run = params.get("path_to_model", None) is None
    init_run_pkl_path = os.path.join(pickle_src_path, f'init_run_{"type" if use_types else "notype"}.t7')

    dict_embed_data = None

    for epoch_idx in trange(int(num_train_epochs), desc="Epoch"):
        model.train()
        torch.cuda.empty_cache()
        tr_loss = 0
        results = None

        # Check if embeddings and index can be loaded
        init_run_data_loaded = False
        if init_base_model_run:
            if os.path.isfile(init_run_pkl_path):
                logger.info('Loading init run data')
                init_run_data = torch.load(init_run_pkl_path)
                init_run_data_loaded = True
        load_stored_data = init_base_model_run and init_run_data_loaded

        # Compute mention and entity embeddings at the start of each epoch
        if use_types:
            if load_stored_data:
                train_dict_embeddings, dict_idxs_by_type = init_run_data[
                    'train_dict_embeddings'], init_run_data['dict_idxs_by_type']
                train_dict_indexes = data_process.get_index_from_embeds(
                    train_dict_embeddings, dict_idxs_by_type,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
                train_men_embeddings, men_idxs_by_type = init_run_data[
                    'train_men_embeddings'], init_run_data['men_idxs_by_type']
                train_men_indexes = data_process.get_index_from_embeds(
                    train_men_embeddings, men_idxs_by_type,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
            else:
                logger.info('Embedding and indexing')
                if dict_embed_data is not None:
                    train_dict_embeddings, train_dict_indexes, dict_idxs_by_type = (
                        dict_embed_data['dict_embeds'],
                        dict_embed_data['dict_indexes'],
                        dict_embed_data['dict_idxs_by_type'])
                else:
                    train_dict_embeddings, train_dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
                        reranker, entity_dict_vecs, encoder_type="candidate",
                        n_gpu=n_gpu, corpus=entity_dictionary,
                        force_exact_search=params['force_exact_search'],
                        batch_size=params['embed_batch_size'],
                        probe_mult_factor=params['probe_mult_factor'])
                train_men_embeddings, train_men_indexes, men_idxs_by_type = data_process.embed_and_index(
                    reranker, train_men_vecs, encoder_type="context",
                    n_gpu=n_gpu, corpus=train_processed_data,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])
        else:
            if load_stored_data:
                train_dict_embeddings = init_run_data['train_dict_embeddings']
                train_dict_index = data_process.get_index_from_embeds(
                    train_dict_embeddings,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
                train_men_embeddings = init_run_data['train_men_embeddings']
                train_men_index = data_process.get_index_from_embeds(
                    train_men_embeddings,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
            else:
                logger.info('Embedding and indexing')
                if dict_embed_data is not None:
                    train_dict_embeddings, train_dict_index = dict_embed_data[
                        'dict_embeds'], dict_embed_data['dict_index']
                else:
                    train_dict_embeddings, train_dict_index = data_process.embed_and_index(
                        reranker, entity_dict_vecs, encoder_type="candidate",
                        n_gpu=n_gpu,
                        force_exact_search=params['force_exact_search'],
                        batch_size=params['embed_batch_size'],
                        probe_mult_factor=params['probe_mult_factor'])
                train_men_embeddings, train_men_index = data_process.embed_and_index(
                    reranker, train_men_vecs, encoder_type="context",
                    n_gpu=n_gpu,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])

        # Save the initial embeddings and index if this is the first run and data isn't persistent
        if init_base_model_run and not load_stored_data:
            init_run_data = {}
            init_run_data['train_dict_embeddings'] = train_dict_embeddings
            init_run_data['train_men_embeddings'] = train_men_embeddings
            if use_types:
                init_run_data['dict_idxs_by_type'] = dict_idxs_by_type
                init_run_data['men_idxs_by_type'] = men_idxs_by_type
            # NOTE: Cannot pickle faiss index because it is a SwigPyObject
            torch.save(init_run_data, init_run_pkl_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)

        init_base_model_run = False

        if params["silent"]:
            iter_ = train_dataloader
        else:
            iter_ = tqdm(train_dataloader, desc="Batch")

        logger.info("Starting KNN search...")
        if not use_types:
            _, dict_nns = train_dict_index.search(train_men_embeddings, knn)
        else:
            dict_nns = np.zeros((len(train_men_embeddings), knn))
            for entity_type in train_men_indexes:
                men_embeds_by_type = train_men_embeddings[men_idxs_by_type[entity_type]]
                _, dict_nns_by_type = train_dict_indexes[entity_type].search(men_embeds_by_type, knn)
                dict_nns_idxs = np.array(
                    list(map(lambda x: dict_idxs_by_type[entity_type][x], dict_nns_by_type)))
                for i, idx in enumerate(men_idxs_by_type[entity_type]):
                    dict_nns[idx] = dict_nns_idxs[i]
        logger.info("Search finished")

        for step, batch in enumerate(iter_):
            batch = tuple(t.to(device) for t in batch)
            context_inputs, candidate_idxs, n_gold, mention_idxs = batch
            mention_embeddings = train_men_embeddings[mention_idxs.cpu()]

            if len(mention_embeddings.shape) == 1:
                mention_embeddings = np.expand_dims(mention_embeddings, axis=0)

            # context_inputs: Shape: batch x token_len
            candidate_inputs = np.array([], dtype=np.int64)  # Shape: (batch*knn) x token_len
            label_inputs = torch.tensor(
                [[1] + [0] * (knn - 1)] * n_gold.sum(),
                dtype=torch.float32)  # Shape: batch(with split rows) x knn
            context_inputs_split = torch.zeros(
                (label_inputs.size(0), context_inputs.size(1)),
                dtype=torch.long)  # Shape: batch(with split rows) x token_len
            # label_inputs = (candidate_idxs >= 0).type(torch.float32)  # Shape: batch x knn

            for i, m_embed in enumerate(mention_embeddings):
                knn_dict_idxs = dict_nns[mention_idxs[i]]
                knn_dict_idxs = knn_dict_idxs.astype(np.int64).flatten()
                gold_idxs = candidate_idxs[i][:n_gold[i]].cpu()
                for ng, gold_idx in enumerate(gold_idxs):
                    context_inputs_split[i + ng] = context_inputs[i]
                    candidate_inputs = np.concatenate(
                        (candidate_inputs,
                         np.concatenate(
                             ([gold_idx],
                              knn_dict_idxs[~np.isin(knn_dict_idxs, gold_idxs)]))[:knn]))
            candidate_inputs = torch.tensor(
                list(map(lambda x: entity_dict_vecs[x].numpy(), candidate_inputs))).cuda()
            context_inputs_split = context_inputs_split.cuda()
            label_inputs = label_inputs.cuda()

            loss, _ = reranker(context_inputs_split, candidate_inputs, label_inputs,
                               pos_neg_loss=params["pos_neg_loss"])

            if grad_acc_steps > 1:
                loss = loss / grad_acc_steps

            tr_loss += loss.item()

            if (step + 1) % (params["print_interval"] * grad_acc_steps) == 0:
                logger.info("Step {} - epoch {} average loss: {}\n".format(
                    step, epoch_idx,
                    tr_loss / (params["print_interval"] * grad_acc_steps)))
                tr_loss = 0

            loss.backward()

            if (step + 1) % grad_acc_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), params["max_grad_norm"])
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if (step + 1) % (params["eval_interval"] * grad_acc_steps) == 0:
                logger.info("Evaluation on the development dataset")
                evaluate(reranker, entity_dict_vecs, valid_men_vecs, device=device,
                         logger=logger, knn=knn, n_gpu=n_gpu,
                         entity_data=entity_dictionary,
                         query_data=valid_processed_data,
                         silent=params["silent"],
                         use_types=use_types or params["use_types_for_eval"],
                         embed_batch_size=params['embed_batch_size'],
                         force_exact_search=use_types or params["use_types_for_eval"] or params["force_exact_search"],
                         probe_mult_factor=params['probe_mult_factor'])
                model.train()
                logger.info("\n")

        logger.info("***** Saving fine-tuned model *****")
        epoch_output_folder_path = os.path.join(model_output_path, "epoch_{}".format(epoch_idx))
        utils.save_model(model, tokenizer, epoch_output_folder_path)
        logger.info(f"Model saved at {epoch_output_folder_path}")

        normalized_accuracy, dict_embed_data = evaluate(
            reranker, entity_dict_vecs, valid_men_vecs, device=device,
            logger=logger, knn=knn, n_gpu=n_gpu, entity_data=entity_dictionary,
            query_data=valid_processed_data, silent=params["silent"],
            use_types=use_types or params["use_types_for_eval"],
            embed_batch_size=params['embed_batch_size'],
            force_exact_search=use_types or params["use_types_for_eval"] or params["force_exact_search"],
            probe_mult_factor=params['probe_mult_factor'])

        ls = [best_score, normalized_accuracy]
        li = [best_epoch_idx, epoch_idx]
        best_score = ls[np.argmax(ls)]
        best_epoch_idx = li[np.argmax(ls)]
        logger.info("\n")

    execution_time = (time.time() - time_start) / 60
    utils.write_to_file(
        os.path.join(model_output_path, "training_time.txt"),
        "The training took {} minutes\n".format(execution_time),
    )
    logger.info("The training took {} minutes\n".format(execution_time))

    # Save the best model in the parent_dir
    logger.info("Best performance in epoch: {}".format(best_epoch_idx))
    params["path_to_model"] = os.path.join(model_output_path, "epoch_{}".format(best_epoch_idx))
    utils.save_model(reranker.model, tokenizer, model_output_path)
    logger.info(f"Best model saved at {model_output_path}")
def evaluate(reranker, valid_dict_vecs, valid_men_vecs, device, logger, knn, n_gpu,
             entity_data, query_data, silent=False, use_types=False,
             embed_batch_size=768, force_exact_search=False, probe_mult_factor=1):
    torch.cuda.empty_cache()

    reranker.model.eval()
    n_entities = len(valid_dict_vecs)
    n_mentions = len(valid_men_vecs)
    joint_graphs = {}
    max_knn = 4
    for k in [0, 1, 2, 4]:
        joint_graphs[k] = {
            'rows': np.array([]),
            'cols': np.array([]),
            'data': np.array([]),
            'shape': (n_entities + n_mentions, n_entities + n_mentions)
        }

    if use_types:
        logger.info("Eval: Dictionary: Embedding and building index")
        dict_embeds, dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
            reranker, valid_dict_vecs, encoder_type="candidate", n_gpu=n_gpu,
            corpus=entity_data, force_exact_search=force_exact_search,
            batch_size=embed_batch_size, probe_mult_factor=probe_mult_factor)
        logger.info("Eval: Queries: Embedding and building index")
        men_embeds, men_indexes, men_idxs_by_type = data_process.embed_and_index(
            reranker, valid_men_vecs, encoder_type="context", n_gpu=n_gpu,
            corpus=query_data, force_exact_search=force_exact_search,
            batch_size=embed_batch_size, probe_mult_factor=probe_mult_factor)
    else:
        logger.info("Eval: Dictionary: Embedding and building index")
        dict_embeds, dict_index = data_process.embed_and_index(
            reranker, valid_dict_vecs, 'candidate', n_gpu=n_gpu,
            force_exact_search=force_exact_search, batch_size=embed_batch_size,
            probe_mult_factor=probe_mult_factor)
        logger.info("Eval: Queries: Embedding and building index")
        men_embeds, men_index = data_process.embed_and_index(
            reranker, valid_men_vecs, 'context', n_gpu=n_gpu,
            force_exact_search=force_exact_search, batch_size=embed_batch_size,
            probe_mult_factor=probe_mult_factor)

    logger.info("Eval: Starting KNN search...")
    # Fetch the single nearest entity candidate and the (k+1) NN mention candidates
    # for each mention query
    if not use_types:
        nn_ent_dists, nn_ent_idxs = dict_index.search(men_embeds, 1)
        nn_men_dists, nn_men_idxs = men_index.search(men_embeds, max_knn + 1)
    else:
        nn_ent_idxs = np.zeros((len(men_embeds), 1), dtype=int)
        nn_ent_dists = np.zeros((len(men_embeds), 1), dtype='float64')
        nn_men_idxs = np.zeros((len(men_embeds), max_knn + 1), dtype=int)
        nn_men_dists = np.zeros((len(men_embeds), max_knn + 1), dtype='float64')
        for entity_type in men_indexes:
            men_embeds_by_type = men_embeds[men_idxs_by_type[entity_type]]
            nn_ent_dists_by_type, nn_ent_idxs_by_type = dict_indexes[entity_type].search(
                men_embeds_by_type, 1)
            nn_men_dists_by_type, nn_men_idxs_by_type = men_indexes[entity_type].search(
                men_embeds_by_type, max_knn + 1)
            nn_ent_idxs_by_type = np.array(
                list(map(lambda x: dict_idxs_by_type[entity_type][x], nn_ent_idxs_by_type)))
            nn_men_idxs_by_type = np.array(
                list(map(lambda x: men_idxs_by_type[entity_type][x], nn_men_idxs_by_type)))
            for i, idx in enumerate(men_idxs_by_type[entity_type]):
                nn_ent_idxs[idx] = nn_ent_idxs_by_type[i]
                nn_ent_dists[idx] = nn_ent_dists_by_type[i]
                nn_men_idxs[idx] = nn_men_idxs_by_type[i]
                nn_men_dists[idx] = nn_men_dists_by_type[i]
    logger.info("Eval: Search finished")

    logger.info('Eval: Building graphs')
    for men_query_idx, men_embed in enumerate(
            tqdm(men_embeds, total=len(men_embeds), desc="Eval: Building graphs")):
        # Get nearest entity candidate
        dict_cand_idx = nn_ent_idxs[men_query_idx][0]
        dict_cand_score = nn_ent_dists[men_query_idx][0]

        # Filter candidates to remove mention query and keep only the top k candidates
        men_cand_idxs = nn_men_idxs[men_query_idx]
        men_cand_scores = nn_men_dists[men_query_idx]
        filter_mask = men_cand_idxs != men_query_idx
        men_cand_idxs, men_cand_scores = men_cand_idxs[filter_mask][:max_knn], men_cand_scores[filter_mask][:max_knn]

        # Add edges to the graphs
        for k in joint_graphs:
            joint_graph = joint_graphs[k]
            # Add mention-entity edge
            joint_graph['rows'] = np.append(
                joint_graph['rows'],
                [n_entities + men_query_idx])  # Mentions added at an offset of maximum entities
            joint_graph['cols'] = np.append(joint_graph['cols'], dict_cand_idx)
            joint_graph['data'] = np.append(joint_graph['data'], dict_cand_score)
            if k > 0:
                # Add mention-mention edges
                joint_graph['rows'] = np.append(
                    joint_graph['rows'],
                    [n_entities + men_query_idx] * len(men_cand_idxs[:k]))
                joint_graph['cols'] = np.append(
                    joint_graph['cols'], n_entities + men_cand_idxs[:k])
                joint_graph['data'] = np.append(
                    joint_graph['data'], men_cand_scores[:k])

    max_eval_acc = -1.
    for k in joint_graphs:
        logger.info(f"\nEval: Graph (k={k}):")
        # Partition graph based on cluster-linking constraints
        partitioned_graph, clusters = eval_cluster_linking.partition_graph(
            joint_graphs[k], n_entities, directed=True, return_clusters=True)
        # Infer predictions from clusters
        result = eval_cluster_linking.analyzeClusters(clusters, entity_data, query_data, k)
        acc = float(result['accuracy'].split(' ')[0])
        max_eval_acc = max(acc, max_eval_acc)
        logger.info(f"Eval: accuracy for graph@k={k}: {acc}%")
    logger.info(f"Eval: Best accuracy: {max_eval_acc}%")
    return max_eval_acc, {
        'dict_embeds': dict_embeds,
        'dict_indexes': dict_indexes,
        'dict_idxs_by_type': dict_idxs_by_type
    } if use_types else {
        'dict_embeds': dict_embeds,
        'dict_index': dict_index
    }
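
# Note on the accuracy parsing above: `analyzeClusters` (from eval_cluster_linking)
# is assumed to report accuracy as a string whose first space-separated token is
# the numeric value (e.g., a hypothetical "81.25 (13/16)"), which is why
# float(result['accuracy'].split(' ')[0]) recovers it.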
def evaluate_ind_pred(reranker,
                      valid_dataloader,
                      valid_dict_vecs,
                      params,
                      device,
                      logger,
                      knn,
                      n_gpu,
                      entity_data,
                      query_data,
                      use_types=False,
                      embed_batch_size=768):
    reranker.model.eval()
    # Accommodate the approximate nature of the knn procedure by retrieving
    # more samples than needed and then filtering
    knn = max(16, 2 * knn)
    iter_ = valid_dataloader if params["silent"] else tqdm(
        valid_dataloader, desc="Evaluation")

    eval_accuracy = 0.0
    nb_eval_examples = 0
    nb_eval_steps = 0

    if not use_types:
        valid_dict_embeddings, valid_dict_index = data_process.embed_and_index(
            reranker,
            valid_dict_vecs,
            encoder_type="candidate",
            n_gpu=n_gpu,
            batch_size=embed_batch_size)
    else:
        valid_dict_embeddings, valid_dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
            reranker,
            valid_dict_vecs,
            encoder_type="candidate",
            n_gpu=n_gpu,
            corpus=entity_data,
            batch_size=embed_batch_size)

    for step, batch in enumerate(iter_):
        batch = tuple(t.to(device) for t in batch)
        context_inputs, candidate_idxs, n_gold, mention_idxs = batch
        with torch.no_grad():
            mention_embeddings = reranker.encode_context(context_inputs)

            # context_inputs: Shape: batch x token_len
            # NOTE: np.int was removed from NumPy; np.int64 preserves the intent
            candidate_inputs = np.array(
                [], dtype=np.int64)  # Shape: (batch*knn) x token_len
            label_inputs = torch.zeros(
                (context_inputs.shape[0], knn),
                dtype=torch.float32)  # Shape: batch x knn

            for i, m_embed in enumerate(mention_embeddings):
                if use_types:
                    entity_type = query_data[mention_idxs[i]]['type']
                    valid_dict_index = valid_dict_indexes[entity_type]
                _, knn_dict_idxs = valid_dict_index.search(
                    np.expand_dims(m_embed, axis=0), knn)
                knn_dict_idxs = knn_dict_idxs.astype(np.int64).flatten()
                if use_types:
                    # Map type-specific indices to the entire dictionary
                    knn_dict_idxs = list(
                        map(lambda x: dict_idxs_by_type[entity_type][x],
                            knn_dict_idxs))
                gold_idxs = candidate_idxs[i][:n_gold[i]].cpu()
                candidate_inputs = np.concatenate(
                    (candidate_inputs, knn_dict_idxs))
                label_inputs[i] = torch.tensor(
                    [1 if nn in gold_idxs else 0 for nn in knn_dict_idxs])

            candidate_inputs = torch.tensor(
                list(map(lambda x: valid_dict_vecs[x].numpy(),
                         candidate_inputs))).cuda()
            context_inputs = context_inputs.cuda()
            label_inputs = label_inputs.cuda()

            logits = reranker(context_inputs, candidate_inputs, label_inputs,
                              only_logits=True)
            logits = logits.detach().cpu().numpy()
            # A prediction is correct when the highest-scoring retrieved
            # candidate is one of the gold entities
            tmp_eval_accuracy = int(
                torch.sum(label_inputs[np.arange(label_inputs.shape[0]),
                                       np.argmax(logits, axis=1)] == 1))
            eval_accuracy += tmp_eval_accuracy
            nb_eval_examples += context_inputs.size(0)
            nb_eval_steps += 1

    normalized_eval_accuracy = eval_accuracy / nb_eval_examples
    logger.info("Eval accuracy: %.5f" % normalized_eval_accuracy)
    return normalized_eval_accuracy
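

# --- Illustrative sketch (not part of the original pipeline) ---
# Both evaluation routines over-fetch approximate kNN results (e.g.
# max(16, 2 * knn) above) and filter afterwards, since the search can return
# the query itself or out-of-scope hits. A minimal numpy-only illustration
# of that filter-then-truncate pattern:
def _example_filter_topk(cand_idxs, cand_scores, query_idx, k):
    # Drop the query from its own candidate list, then keep the k best hits
    mask = cand_idxs != query_idx
    return cand_idxs[mask][:k], cand_scores[mask][:k]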
def main(params):
    output_path = params["output_path"]
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    logger = utils.get_logger(params["output_path"], 'log-eval')

    pickle_src_path = params["pickle_src_path"]
    if pickle_src_path is None or not os.path.exists(pickle_src_path):
        pickle_src_path = output_path

    embed_data_path = params["embed_data_path"]
    if embed_data_path is None or not os.path.exists(embed_data_path):
        embed_data_path = output_path

    # Init model
    reranker = BiEncoderRanker(params)
    reranker.model.eval()
    tokenizer = reranker.tokenizer
    n_gpu = reranker.n_gpu

    knn = params["knn"]
    use_types = params["use_types"]
    within_doc = params["within_doc"]
    data_split = params["data_split"]  # Default = "test"

    # Load test data
    test_samples = None
    entity_dictionary_loaded = False
    test_dictionary_pkl_path = os.path.join(pickle_src_path,
                                            'test_dictionary.pickle')
    test_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                             'test_tensor_data.pickle')
    test_mention_data_pkl_path = os.path.join(pickle_src_path,
                                              'test_mention_data.pickle')
    if params['transductive']:
        train_tensor_data_pkl_path = os.path.join(pickle_src_path,
                                                  'train_tensor_data.pickle')
        train_mention_data_pkl_path = os.path.join(
            pickle_src_path, 'train_mention_data.pickle')
    if os.path.isfile(test_dictionary_pkl_path):
        print("Loading stored processed entity dictionary...")
        with open(test_dictionary_pkl_path, 'rb') as read_handle:
            test_dictionary = pickle.load(read_handle)
        entity_dictionary_loaded = True
    if os.path.isfile(test_tensor_data_pkl_path) and os.path.isfile(
            test_mention_data_pkl_path):
        print("Loading stored processed test data...")
        with open(test_tensor_data_pkl_path, 'rb') as read_handle:
            test_tensor_data = pickle.load(read_handle)
        with open(test_mention_data_pkl_path, 'rb') as read_handle:
            mention_data = pickle.load(read_handle)
    else:
        test_samples = utils.read_dataset(data_split, params["data_path"])
        if not entity_dictionary_loaded:
            with open(os.path.join(params["data_path"], 'dictionary.pickle'),
                      'rb') as read_handle:
                test_dictionary = pickle.load(read_handle)

        # Check if the dataset has multiple ground-truth labels
        mult_labels = "labels" in test_samples[0].keys()
        if params["filter_unlabeled"]:
            # Filter out samples without gold entities
            test_samples = list(
                filter(
                    lambda sample: (len(sample["labels"]) > 0)
                    if mult_labels else (sample["label"] is not None),
                    test_samples))
        logger.info("Read %d test samples." % len(test_samples))
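        # process_mention_data tokenizes the mention contexts (and, when no
        # pre-processed dictionary was loaded, the entity descriptions) into
        # fixed-length token-id tensors consumed by the bi-encoder below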
        mention_data, test_dictionary, test_tensor_data = data_process.process_mention_data(
            test_samples,
            test_dictionary,
            tokenizer,
            params["max_context_length"],
            params["max_cand_length"],
            multi_label_key="labels" if mult_labels else None,
            context_key=params["context_key"],
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            knn=knn,
            dictionary_processed=entity_dictionary_loaded)
        print("Saving processed test data...")
        if not entity_dictionary_loaded:
            with open(test_dictionary_pkl_path, 'wb') as write_handle:
                pickle.dump(test_dictionary,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            entity_dictionary_loaded = True
        with open(test_tensor_data_pkl_path, 'wb') as write_handle:
            pickle.dump(test_tensor_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        with open(test_mention_data_pkl_path, 'wb') as write_handle:
            pickle.dump(mention_data,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

    # Store test dictionary token ids
    test_dict_vecs = torch.tensor(list(map(lambda x: x['ids'],
                                           test_dictionary)),
                                  dtype=torch.long)
    # Store test mention token ids
    test_men_vecs = test_tensor_data[:][0]

    n_entities = len(test_dict_vecs)
    n_mentions = len(test_tensor_data)

    if within_doc:
        if test_samples is None:
            test_samples, _ = read_data(data_split, params, logger)
        test_context_doc_ids = [s['context_doc_id'] for s in test_samples]

    if params["transductive"]:
        if os.path.isfile(train_tensor_data_pkl_path) and os.path.isfile(
                train_mention_data_pkl_path):
            print("Loading stored processed train data...")
            with open(train_tensor_data_pkl_path, 'rb') as read_handle:
                train_tensor_data = pickle.load(read_handle)
            with open(train_mention_data_pkl_path, 'rb') as read_handle:
                train_mention_data = pickle.load(read_handle)
        else:
            train_samples = utils.read_dataset('train', params["data_path"])

            # Check if the dataset has multiple ground-truth labels
            mult_labels = "labels" in train_samples[0].keys()
            logger.info("Read %d train samples." % len(train_samples))
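            # NOTE: The train mentions are processed against the same
            # test_dictionary (the processed dictionary returned for the
            # train split is discarded via `_`), so entity indices stay
            # aligned between the train and test queries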
            train_mention_data, _, train_tensor_data = data_process.process_mention_data(
                train_samples,
                test_dictionary,
                tokenizer,
                params["max_context_length"],
                params["max_cand_length"],
                multi_label_key="labels" if mult_labels else None,
                context_key=params["context_key"],
                silent=params["silent"],
                logger=logger,
                debug=params["debug"],
                knn=knn,
                dictionary_processed=entity_dictionary_loaded)
            print("Saving processed train data...")
            with open(train_tensor_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_tensor_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            with open(train_mention_data_pkl_path, 'wb') as write_handle:
                pickle.dump(train_mention_data,
                            write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

        # Store train mention token ids
        train_men_vecs = train_tensor_data[:][0]
        n_mentions += len(train_tensor_data)
        n_train_mentions = len(train_tensor_data)

    # Values of k to run the evaluation against
    # (e.g., knn=16 gives [0, 1, 2, 4, 8, 16])
    knn_vals = [0] + [2**i for i in range(int(math.log(knn, 2)) + 1)]
    # Store the maximum evaluation k
    max_knn = knn_vals[-1]

    time_start = time.time()

    # Check if the graphs are already built
    graph_path = os.path.join(output_path, 'graphs.pickle')
    if not params['only_recall'] and os.path.isfile(graph_path):
        print("Loading stored joint graphs...")
        with open(graph_path, 'rb') as read_handle:
            joint_graphs = pickle.load(read_handle)
    else:
        # Initialize graphs to store mention-mention and mention-entity
        # similarity score edges, keyed on k, the number of nearest mentions retrieved
        joint_graphs = {}
        for k in knn_vals:
            joint_graphs[k] = {
                'rows': np.array([]),
                'cols': np.array([]),
                'data': np.array([]),
                'shape': (n_entities + n_mentions, n_entities + n_mentions)
            }
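        # The k=0 graph carries only mention-entity edges, so it evaluates
        # plain 1-NN entity linking; every k > 0 additionally links each
        # mention to its k nearest mention neighbours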
        # Check for and load stored embedding data
        embed_data_path = os.path.join(embed_data_path, 'embed_data.t7')
        embed_data = None
        if os.path.isfile(embed_data_path):
            embed_data = torch.load(embed_data_path)

        if use_types:
            if embed_data is not None:
                logger.info('Loading stored embeddings and computing indexes')
                dict_embeds = embed_data['dict_embeds']
                if 'dict_idxs_by_type' in embed_data:
                    dict_idxs_by_type = embed_data['dict_idxs_by_type']
                else:
                    dict_idxs_by_type = data_process.get_idxs_by_type(
                        test_dictionary)
                dict_indexes = data_process.get_index_from_embeds(
                    dict_embeds,
                    dict_idxs_by_type,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
                men_embeds = embed_data['men_embeds']
                if 'men_idxs_by_type' in embed_data:
                    men_idxs_by_type = embed_data['men_idxs_by_type']
                else:
                    men_idxs_by_type = data_process.get_idxs_by_type(
                        mention_data)
                men_indexes = data_process.get_index_from_embeds(
                    men_embeds,
                    men_idxs_by_type,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
            else:
                logger.info("Dictionary: Embedding and building index")
                dict_embeds, dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
                    reranker,
                    test_dict_vecs,
                    encoder_type="candidate",
                    n_gpu=n_gpu,
                    corpus=test_dictionary,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])
                logger.info("Queries: Embedding and building index")
                vecs = test_men_vecs
                men_data = mention_data
                if params['transductive']:
                    vecs = torch.cat((train_men_vecs, vecs), dim=0)
                    men_data = train_mention_data + mention_data
                men_embeds, men_indexes, men_idxs_by_type = data_process.embed_and_index(
                    reranker,
                    vecs,
                    encoder_type="context",
                    n_gpu=n_gpu,
                    corpus=men_data,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])
        else:
            if embed_data is not None:
                logger.info('Loading stored embeddings and computing indexes')
                dict_embeds = embed_data['dict_embeds']
                dict_index = data_process.get_index_from_embeds(
                    dict_embeds,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
                men_embeds = embed_data['men_embeds']
                men_index = data_process.get_index_from_embeds(
                    men_embeds,
                    force_exact_search=params['force_exact_search'],
                    probe_mult_factor=params['probe_mult_factor'])
            else:
                logger.info("Dictionary: Embedding and building index")
                dict_embeds, dict_index = data_process.embed_and_index(
                    reranker,
                    test_dict_vecs,
                    'candidate',
                    n_gpu=n_gpu,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])
                logger.info("Queries: Embedding and building index")
                vecs = test_men_vecs
                if params['transductive']:
                    vecs = torch.cat((train_men_vecs, vecs), dim=0)
                men_embeds, men_index = data_process.embed_and_index(
                    reranker,
                    vecs,
                    'context',
                    n_gpu=n_gpu,
                    force_exact_search=params['force_exact_search'],
                    batch_size=params['embed_batch_size'],
                    probe_mult_factor=params['probe_mult_factor'])

        # Save the computed embedding data if it wasn't loaded from disk
        if embed_data is None:
            embed_data = {}
            embed_data['dict_embeds'] = dict_embeds
            embed_data['men_embeds'] = men_embeds
            if use_types:
                embed_data['dict_idxs_by_type'] = dict_idxs_by_type
                embed_data['men_idxs_by_type'] = men_idxs_by_type
            # NOTE: Cannot pickle the faiss index because it is a SwigPyObject
            torch.save(embed_data,
                       embed_data_path,
                       pickle_protocol=pickle.HIGHEST_PROTOCOL)

        recall_accuracy = {
            2**i: 0
            for i in range(int(math.log(params['recall_k'], 2)) + 1)
        }
        recall_idxs = [0.] * params['recall_k']

        logger.info("Starting KNN search...")
        # Fetch recall_k (default 16) knn entities for all mentions
        # Fetch (k+1) NN mention candidates
        if not use_types:
            _men_embeds = men_embeds
            if params['transductive']:
                _men_embeds = _men_embeds[n_train_mentions:]
            nn_ent_dists, nn_ent_idxs = dict_index.search(
                _men_embeds, params['recall_k'])
            n_mens_to_fetch = len(_men_embeds) if within_doc else max_knn + 1
            nn_men_dists, nn_men_idxs = men_index.search(
                _men_embeds, n_mens_to_fetch)
        else:
            query_len = len(men_embeds) - (n_train_mentions
                                           if params['transductive'] else 0)
            nn_ent_idxs = np.zeros((query_len, params['recall_k']))
            nn_ent_dists = np.zeros((query_len, params['recall_k']),
                                    dtype='float64')
            nn_men_idxs = -1 * np.ones((query_len, query_len), dtype=int)
            nn_men_dists = -1 * np.ones((query_len, query_len),
                                        dtype='float64')
            for entity_type in men_indexes:
                if params['transductive']:
                    men_embeds_by_type = men_embeds[men_idxs_by_type[
                        entity_type][men_idxs_by_type[entity_type] >=
                                     n_train_mentions]]
                else:
                    men_embeds_by_type = men_embeds[
                        men_idxs_by_type[entity_type]]
                nn_ent_dists_by_type, nn_ent_idxs_by_type = dict_indexes[
                    entity_type].search(men_embeds_by_type,
                                        params['recall_k'])
                nn_ent_idxs_by_type = np.array(
                    list(
                        map(lambda x: dict_idxs_by_type[entity_type][x],
                            nn_ent_idxs_by_type)))
                n_mens_to_fetch = len(
                    men_embeds_by_type) if within_doc else max_knn + 1
                nn_men_dists_by_type, nn_men_idxs_by_type = men_indexes[
                    entity_type].search(
                        men_embeds_by_type,
                        min(n_mens_to_fetch, len(men_embeds_by_type)))
                nn_men_idxs_by_type = np.array(
                    list(
                        map(lambda x: men_idxs_by_type[entity_type][x],
                            nn_men_idxs_by_type)))
                i = -1
                for idx in men_idxs_by_type[entity_type]:
                    if params['transductive']:
                        idx -= n_train_mentions
                    if idx < 0:
                        continue
                    i += 1
                    nn_ent_idxs[idx] = nn_ent_idxs_by_type[i]
                    nn_ent_dists[idx] = nn_ent_dists_by_type[i]
                    nn_men_idxs[idx][:len(nn_men_idxs_by_type[i]
                                          )] = nn_men_idxs_by_type[i]
                    nn_men_dists[idx][:len(nn_men_dists_by_type[i]
                                           )] = nn_men_dists_by_type[i]
        logger.info("Search finished")
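        # recall_idxs[r] counts queries whose gold entity was retrieved at
        # rank r; recall@k is then the fraction of queries with rank < k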
        logger.info('Building graphs')
        # Find the most similar entity and the k-nn mentions for each mention query
        for idx in range(len(nn_ent_idxs)):
            # Get the nearest entity candidate
            dict_cand_idx = nn_ent_idxs[idx][0]
            dict_cand_score = nn_ent_dists[idx][0]
            # Compute the recall metric
            gold_idxs = mention_data[idx][
                "label_idxs"][:mention_data[idx]["n_labels"]]
            recall_idx = np.argwhere(nn_ent_idxs[idx] == gold_idxs[0])
            if len(recall_idx) != 0:
                recall_idx = int(recall_idx)
                recall_idxs[recall_idx] += 1.
                for recall_k in recall_accuracy:
                    if recall_idx < recall_k:
                        recall_accuracy[recall_k] += 1.
            if not params['only_recall']:
                # Remove unfilled (-1) placeholder candidates
                filter_mask_neg1 = nn_men_idxs[idx] != -1
                men_cand_idxs = nn_men_idxs[idx][filter_mask_neg1]
                men_cand_scores = nn_men_dists[idx][filter_mask_neg1]
                if within_doc:
                    men_cand_idxs, wd_mask = filter_by_context_doc_id(
                        men_cand_idxs,
                        test_context_doc_ids[idx],
                        test_context_doc_ids,
                        return_numpy=True)
                    men_cand_scores = men_cand_scores[wd_mask]
                # Filter candidates to remove the mention query and keep only the top k candidates
                filter_mask = men_cand_idxs != idx
                men_cand_idxs, men_cand_scores = men_cand_idxs[
                    filter_mask][:max_knn], men_cand_scores[
                        filter_mask][:max_knn]
                if params['transductive']:
                    idx += n_train_mentions
                # Add edges to the graphs
                for k in joint_graphs:
                    joint_graph = joint_graphs[k]
                    # Add the mention-entity edge; mentions are stored at an offset of n_entities
                    joint_graph['rows'] = np.append(joint_graph['rows'],
                                                    [n_entities + idx])
                    joint_graph['cols'] = np.append(joint_graph['cols'],
                                                    dict_cand_idx)
                    joint_graph['data'] = np.append(joint_graph['data'],
                                                    dict_cand_score)
                    if k > 0:
                        # Add mention-mention edges
                        joint_graph['rows'] = np.append(
                            joint_graph['rows'],
                            [n_entities + idx] * len(men_cand_idxs[:k]))
                        joint_graph['cols'] = np.append(
                            joint_graph['cols'],
                            n_entities + men_cand_idxs[:k])
                        joint_graph['data'] = np.append(
                            joint_graph['data'], men_cand_scores[:k])

        if params['transductive']:
            # Add positive-infinity mention-entity edges from the training
            # queries to their labeled entities
            for idx, train_men in enumerate(train_mention_data):
                dict_cand_idx = train_men["label_idxs"][0]
                for k in joint_graphs:
                    joint_graph = joint_graphs[k]
                    joint_graph['rows'] = np.append(joint_graph['rows'],
                                                    [n_entities + idx])
                    joint_graph['cols'] = np.append(joint_graph['cols'],
                                                    dict_cand_idx)
                    joint_graph['data'] = np.append(joint_graph['data'],
                                                    float('inf'))

        # Compute and log the recall metric
        recall_idx_mode = np.argmax(recall_idxs)
        recall_idx_mode_prop = recall_idxs[recall_idx_mode] / np.sum(
            recall_idxs)
        logger.info(f"""
Recall metrics (for {len(nn_ent_idxs)} queries):
---------------""")
        logger.info(
            f"highest recall idx = {recall_idx_mode} ({recall_idxs[recall_idx_mode]}/{np.sum(recall_idxs)} = {recall_idx_mode_prop})"
        )
        for recall_k in recall_accuracy:
            recall_accuracy[recall_k] /= len(nn_ent_idxs)
            logger.info(f"recall@{recall_k} = {recall_accuracy[recall_k]}")

        if params['only_recall']:
            exit()

        # Pickle the graphs
        print("Saving joint graphs...")
        with open(graph_path, 'wb') as write_handle:
            pickle.dump(joint_graphs,
                        write_handle,
                        protocol=pickle.HIGHEST_PROTOCOL)

        if params['only_embed_and_build']:
            logger.info(f"Saved embedding data at: {embed_data_path}")
            logger.info(f"Saved graphs at: {graph_path}")
            exit()

    graph_mode = params.get('graph_mode', None)
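    # Unless params['graph_mode'] pins a single mode, the stored graphs are
    # partitioned and scored under both directed and undirected
    # cluster-linking constraints, for every k up to the configured knn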
    result_overview = {
        'n_entities': n_entities,
        'n_mentions':
        n_mentions - (n_train_mentions if params['transductive'] else 0)
    }

    results = {}
    if graph_mode is None or graph_mode not in ['directed', 'undirected']:
        results['directed'] = []
        results['undirected'] = []
    else:
        results[graph_mode] = []

    knn_fetch_time = time.time() - time_start
    graph_processing_time = time.time()
    n_graphs_processed = 0.

    for mode in results:
        print(f'\nEvaluation mode: {mode.upper()}')
        for k in joint_graphs:
            if k <= knn:
                print(f"\nGraph (k={k}):")
                # Partition the graph based on cluster-linking constraints
                partitioned_graph, clusters = partition_graph(
                    joint_graphs[k],
                    n_entities,
                    mode == 'directed',
                    return_clusters=True)
                # Infer predictions from clusters
                result = analyzeClusters(
                    clusters, test_dictionary, mention_data, k,
                    n_train_mentions if params['transductive'] else 0)
                # Store the result
                results[mode].append(result)
                n_graphs_processed += 1

    avg_graph_processing_time = (time.time() -
                                 graph_processing_time) / n_graphs_processed
    avg_per_graph_time = (knn_fetch_time + avg_graph_processing_time) / 60
    execution_time = (time.time() - time_start) / 60

    # Store results, tagged with the current Unix timestamp
    output_file_name = os.path.join(output_path,
                                    f"eval_results_{int(time.time())}")

    try:
        for recall_k in recall_accuracy:
            result_overview[f'recall@{recall_k}'] = recall_accuracy[recall_k]
    except NameError:
        # recall_accuracy is only computed when the graphs are built in this run
        logger.info(
            "Recall data not available since graphs were loaded from disk")

    for mode in results:
        mode_results = results[mode]
        result_overview[mode] = {}
        for r in mode_results:
            k = r['knn_mentions']
            result_overview[mode][f'accuracy@knn{k}'] = r['accuracy']
            logger.info(f"{mode} accuracy@knn{k} = {r['accuracy']}")
            output_file = f'{output_file_name}-{mode}-{k}.json'
            with open(output_file, 'w') as f:
                json.dump(r, f, indent=2)
            print(f"\nPredictions ({mode}) @knn{k} saved at: {output_file}")
    with open(f'{output_file_name}.json', 'w') as f:
        json.dump(result_overview, f, indent=2)
    print(f"\nPredictions overview saved at: {output_file_name}.json")

    logger.info("\nThe avg. per graph evaluation time is {} minutes\n".format(
        avg_per_graph_time))
    logger.info(
        "\nThe total evaluation took {} minutes\n".format(execution_time))
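

# --- Illustrative sketch (not part of the original file) ---
# main() expects a flat params dict. The keys below are the ones this module
# reads directly; the values are hypothetical placeholders, and BiEncoderRanker
# will additionally need its own model-config keys, which are not shown here.
if __name__ == "__main__":
    example_params = {
        'output_path': 'out/eval',      # logs, graphs, and results go here
        'pickle_src_path': None,        # falls back to output_path
        'embed_data_path': None,        # falls back to output_path
        'data_path': 'data/processed',  # must contain dictionary.pickle
        'data_split': 'test',
        'knn': 16,
        'recall_k': 16,
        'use_types': False,
        'within_doc': False,
        'transductive': False,
        'filter_unlabeled': True,
        'max_context_length': 128,
        'max_cand_length': 128,
        'context_key': 'context',
        'silent': False,
        'debug': False,
        'only_recall': False,
        'only_embed_and_build': False,
        'force_exact_search': False,
        'probe_mult_factor': 1,
        'embed_batch_size': 768,
        'graph_mode': None,
    }
    # main(example_params)  # uncomment once the ranker config keys are added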