def get_similarity(args, test_dataset, top_n):
    logging.info("***** Get Similarity *****")
    ent_list = list(sorted(list(test_dataset.ent2idx.keys())))
    rel_list = list(sorted(list(test_dataset.rel2idx.keys())))
    id2ent = test_dataset.id2ent

    # ========= get all embeddings ==========
    save_dir = args.model_name_or_path if dir_exists(
        args.model_name_or_path) else args.output_dir
    save_path = os.path.join(save_dir, "saved_emb_mat.np")
    emb_mat = get_emb_mat(save_dir, save_path, ent_list, rel_list,
                          test_dataset, args)
    ent2emb, ent_rel2emb = assign_emb2elements(ent_list, rel_list, emb_mat)

    emb_list = []
    for i in range(len(ent_list)):
        emb_list.append(ent2emb[id2ent[i]])
    emb_list = torch.stack(emb_list, dim=0).to(args.device)

    similarity_score_mtx = []
    # similarity_index_mtx = []
    # if args.get_cosine_similarity:
    for i in tqdm(range(len(ent_list))):
        line_similarity = torch.cosine_similarity(
            torch.stack([emb_list[i]] * len(emb_list)), emb_list,
            dim=-1).detach().cpu().numpy()
        line_simi_index = np.argsort(
            -line_similarity)  # np.argsort sorts elements from small to large
        line_simi_score = line_similarity[line_simi_index]
        # similarity_index_mtx.append(line_simi_index[:top_n])
        similarity_score_mtx.append(line_simi_score[:top_n])
    # similarity_index_mtx = np.array(similarity_index_mtx)
    similarity_score_mtx = np.array(similarity_score_mtx)
    if dir_exists(save_dir):
        # np.save(os.path.join(save_dir, "similarity_index_mtx"), similarity_index_mtx)
        np.save(os.path.join(save_dir, "similarity_score_mtx"),
                similarity_score_mtx)
    # else:
    #     raise NotImplementedError
    return similarity_score_mtx
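
# Usage sketch (illustrative, not part of the original pipeline): reload the
# matrix written by get_similarity above. Note that np.save appends ".npy" to
# the given file name, and row i holds the descending top-N cosine scores for
# entity test_dataset.id2ent[i].
def load_similarity_scores(save_dir):
    return np.load(os.path.join(save_dir, "similarity_score_mtx.npy"))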
def get_ent_emb(args, test_dataset):
    logging.info("***** Get all entity embeddings *****")
    ent_list = list(sorted(list(test_dataset.ent2idx.keys())))
    rel_list = list(sorted(list(test_dataset.rel2idx.keys())))
    id2ent = test_dataset.id2ent

    # ========= get all embeddings ==========
    save_dir = args.model_name_or_path if dir_exists(
        args.model_name_or_path) else args.output_dir
    save_path = os.path.join(save_dir, "saved_emb_mat.np")
    emb_mat = get_emb_mat(save_dir, save_path, ent_list, rel_list,
                          test_dataset, args)
    ent2emb, ent_rel2emb = assign_emb2elements(ent_list, rel_list, emb_mat)

    emb_list = []
    for i in range(len(ent_list)):
        emb_list.append(ent2emb[id2ent[i]])
    emb_list = torch.stack(emb_list, dim=0)
    if dir_exists(save_dir):
        torch.save(emb_list, os.path.join(save_dir, "ent_emb.pkl"))
    logging.info("***** Finished *****")
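
# Usage sketch (illustrative, not original code): reload the entity embedding
# tensor saved above and query the nearest neighbours of one entity by cosine
# similarity. Row i corresponds to test_dataset.id2ent[i].
def nearest_entities(save_dir, ent_idx, k=5):
    emb = torch.load(os.path.join(save_dir, "ent_emb.pkl"))
    sims = torch.cosine_similarity(emb[ent_idx].unsqueeze(0), emb, dim=-1)
    return torch.topk(sims, k + 1).indices[1:]  # drop the entity itself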
def get_dev_candidate(args, raw_examples, dataset_list, model, verbose=True):
    logging.info("***** Running Prediction *****")
    model.eval()
    # get the last one (i.e., test) to make use of its useful functions and data
    standard_dataset = dataset_list[-1]

    ents = set()
    g_subj2objs = collections.defaultdict(lambda: collections.defaultdict(set))
    g_obj2subjs = collections.defaultdict(lambda: collections.defaultdict(set))
    for _ds in dataset_list:
        for _raw_ex in _ds.raw_examples:
            _head, _rel, _tail = _raw_ex
            ents.add(_head)
            ents.add(_tail)
            g_subj2objs[_head][_rel].add(_tail)
            g_obj2subjs[_tail][_rel].add(_head)
    ent_list = list(sorted(ents))
    # rel_list = list(sorted(standard_dataset.rel_list))
    rel_set = set()
    for triplet in raw_examples:
        _, r, _ = triplet
        rel_set.add(r)
    rel_list = list(sorted(list(rel_set)))

    # ========= get all embeddings ==========
    print("get all embeddings")
    save_dir = args.model_name_or_path
    save_path = os.path.join(args.model_name_or_path, "saved_dev_emb_mat.np")
    new_dev_path = os.path.join("./data/" + args.dataset + "/new_dev.dict")
    if dir_exists(save_dir) and file_exists(save_path):
        print("\tload from file")
        emb_mat = torch.load(save_path)
    else:
        print("\tget all ids")
        input_ids_list, mask_ids_list, segment_ids_list = [], [], []
        for _ent in tqdm(ent_list):
            for _idx_r, _rel in enumerate(rel_list):
                head_ids, rel_ids, tail_ids = standard_dataset.convert_raw_example_to_features(
                    [_ent, _rel, _ent], method="4")
                head_ids, rel_ids, tail_ids = head_ids[1:-1], rel_ids[
                    1:-1], tail_ids[1:-1]
                # truncate
                max_ent_len = standard_dataset.max_seq_length - 3 - len(rel_ids)
                head_ids = head_ids[:max_ent_len]
                tail_ids = tail_ids[:max_ent_len]

                src_input_ids = [standard_dataset._cls_id] + head_ids + [
                    standard_dataset._sep_id
                ] + rel_ids + [standard_dataset._sep_id]
                src_mask_ids = [1] * len(src_input_ids)
                src_segment_ids = [0] * (len(head_ids) + 2) + [1] * (len(rel_ids) + 1)

                if _idx_r == 0:
                    tgt_input_ids = [standard_dataset._cls_id
                                     ] + tail_ids + [standard_dataset._sep_id]
                    tgt_mask_ids = [1] * len(tgt_input_ids)
                    tgt_segment_ids = [0] * (len(tail_ids) + 2)
                    input_ids_list.append(tgt_input_ids)
                    mask_ids_list.append(tgt_mask_ids)
                    segment_ids_list.append(tgt_segment_ids)
                input_ids_list.append(src_input_ids)
                mask_ids_list.append(src_mask_ids)
                segment_ids_list.append(src_segment_ids)

        # padding
        max_len = max(len(_e) for _e in input_ids_list)
        assert max_len <= standard_dataset.max_seq_length
        input_ids_list = [
            _e + [standard_dataset._pad_id] * (max_len - len(_e))
            for _e in input_ids_list
        ]
        mask_ids_list = [
            _e + [0] * (max_len - len(_e)) for _e in mask_ids_list
        ]
        segment_ids_list = [
            _e + [0] * (max_len - len(_e)) for _e in segment_ids_list
        ]

        # dataset
        enc_dataset = TensorDataset(
            torch.tensor(input_ids_list, dtype=torch.long),
            torch.tensor(mask_ids_list, dtype=torch.long),
            torch.tensor(segment_ids_list, dtype=torch.long),
        )
        enc_dataloader = DataLoader(enc_dataset,
                                    sampler=SequentialSampler(enc_dataset),
                                    batch_size=args.eval_batch_size * 2)

        print("\tget all emb via model")
        embs_list = []
        rep_torch_dtype = None
        for batch in tqdm(enc_dataloader,
                          desc="entity embedding",
                          disable=(not verbose)):
            batch = tuple(t.to(args.device) for t in batch)
            _input_ids, _mask_ids, _segment_ids = batch
            with torch.no_grad():
                embs = model.encoder(_input_ids,
                                     attention_mask=_mask_ids,
                                     token_type_ids=_segment_ids)
                if rep_torch_dtype is None:
                    rep_torch_dtype = embs.dtype
                embs = embs.detach().cpu()
            embs_list.append(embs)
        emb_mat = torch.cat(embs_list, dim=0).contiguous()
        assert emb_mat.shape[0] == len(input_ids_list)
        # save emb_mat
        if dir_exists(save_dir):
            torch.save(emb_mat, save_path)

    # assign embeddings to entities
    assert len(ent_list) * (1 + len(rel_list)) == emb_mat.shape[0]
    ent_rel2emb = collections.defaultdict(dict)  # (head, rel) -> src embedding
    ent2emb = dict()  # tail entity -> tgt embedding
    ptr_row = 0
    for _ent in ent_list:
        for _idx_r, _rel in enumerate(rel_list):
            if _idx_r == 0:
                ent2emb[_ent] = emb_mat[ptr_row]
                ptr_row += 1
            ent_rel2emb[_ent][_rel] = emb_mat[ptr_row]
            ptr_row += 1

    # ========== get candidates ==========
    # * begin to get hit
    new_dev_dict = collections.defaultdict(dict)
    for _idx_ex, _triplet in enumerate(
            tqdm(raw_examples, desc="get candidates")):
        _head, _rel, _tail = _triplet
        head_ent_list = []
        tail_ent_list = []

        # head corrupt
        _pos_head_ents = g_obj2subjs[_tail][_rel]
        _neg_head_ents = ents - _pos_head_ents
        head_ent_list.append(_head)  # positive example
        head_ent_list.extend(_neg_head_ents)  # negative examples
        tail_ent_list.extend([_tail] * (1 + len(_neg_head_ents)))
        split_idx = len(head_ent_list)

        # tail corrupt
        _pos_tail_ents = g_subj2objs[_head][_rel]
        _neg_tail_ents = ents - _pos_tail_ents
        head_ent_list.extend([_head] * (1 + len(_neg_tail_ents)))
        tail_ent_list.append(_tail)  # positive example
        tail_ent_list.extend(_neg_tail_ents)  # negative examples

        triplet_list = list([_h, _rel, _t]
                            for _h, _t in zip(head_ent_list, tail_ent_list))

        # build dataset
        rep_src_list = [ent_rel2emb[_h][_rel] for _h, _rel, _ in triplet_list]
        rep_tgt_list = [ent2emb[_t] for _, _, _t in triplet_list]
        all_rep_src = torch.stack(rep_src_list, dim=0).to(args.device)
        all_rep_tgt = torch.stack(rep_tgt_list, dim=0).to(args.device)

        local_scores_list = []
        sim_batch_size = args.eval_batch_size * 8
        for _idx_r in range(0, all_rep_src.shape[0], sim_batch_size):
            _rep_src, _rep_tgt = all_rep_src[
                _idx_r:_idx_r + sim_batch_size], all_rep_tgt[_idx_r:_idx_r +
                                                             sim_batch_size]
            with torch.no_grad():
                logits = model.classifier(_rep_src, _rep_tgt)
                logits = torch.softmax(logits, dim=-1)
                local_scores = logits.detach().cpu().numpy()[:, 1]
            local_scores_list.append(local_scores)
        scores = np.concatenate(local_scores_list, axis=0)

        # left
        left_scores = scores[:split_idx]
        heads_corrupt_idx = get_triplet_candidate(left_scores)
        heads_corrupt = [head_ent_list[i] for i in heads_corrupt_idx]
        right_scores = scores[split_idx:]
        tails_corrupt_idx = get_triplet_candidate(right_scores)
        tails_corrupt = [
            tail_ent_list[i + split_idx] for i in tails_corrupt_idx
        ]
        new_dev_dict[tuple(_triplet)]["heads_corrupt"] = heads_corrupt
        new_dev_dict[tuple(_triplet)]["tails_corrupt"] = tails_corrupt
    torch.save(new_dev_dict, new_dev_path)
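
# Usage sketch (illustrative, not original code): reload the candidate
# dictionary written by get_dev_candidate. Each dev triple maps to the
# corrupted heads/tails selected by get_triplet_candidate.
def load_dev_candidates(dataset_name):
    new_dev_path = os.path.join("./data/" + dataset_name + "/new_dev.dict")
    # {(h, r, t): {"heads_corrupt": [...], "tails_corrupt": [...]}}
    return torch.load(new_dev_path)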
def collect_case(args, raw_examples, dataset_list, model, verbose=True):
    logging.info("***** Running Case Collection *****")
    model.eval()
    # get the last one (i.e., test) to make use of its useful functions and data
    standard_dataset = dataset_list[-1]

    ents = set()
    g_subj2objs = collections.defaultdict(lambda: collections.defaultdict(set))
    g_obj2subjs = collections.defaultdict(lambda: collections.defaultdict(set))
    for _ds in dataset_list:
        for _raw_ex in _ds.raw_examples:
            _head, _rel, _tail = _raw_ex
            ents.add(_head)
            ents.add(_tail)
            g_subj2objs[_head][_rel].add(_tail)
            g_obj2subjs[_tail][_rel].add(_head)
    ent_list = list(sorted(ents))
    rel_list = list(sorted(standard_dataset.rel_list))

    # ========= get all embeddings ==========
    save_dir = args.model_name_or_path if dir_exists(
        args.model_name_or_path) else args.output_dir
    save_path = os.path.join(save_dir, "saved_emb_mat.np")
    case_save_path = os.path.join(save_dir, "cases_alone.txt")
    dict_save_path = join(save_dir, 'cases_alone.dict')
    emb_mat = get_emb_mat(save_dir, save_path, ent_list, rel_list,
                          standard_dataset, args, model)
    ent2emb, ent_rel2emb = assign_emb2elements(ent_list, rel_list, emb_mat)

    # ========== get logits and distances ==========
    results_dict = collections.defaultdict(dict)
    ent2text_dict = dataset_list[0].ent2text
    for ent_id, ent_text in ent2text_dict.items():
        # NOTE: keep only the first comma-separated clause of each description
        ent2text_dict[ent_id] = ent_text.split(",")[0]
        # ent2text_dict[ent_id] = ent_text

    for _idx_ex, _triplet in enumerate(tqdm(raw_examples, desc="get cases")):
        _head, _rel, _tail = _triplet
        head_ent_list = []
        tail_ent_list = []

        # head corrupt
        _pos_head_ents = g_obj2subjs[_tail][_rel]
        _neg_head_ents = ents - _pos_head_ents
        head_ent_list.append(_head)  # positive example
        head_ent_list.extend(_neg_head_ents)  # negative examples
        tail_ent_list.extend([_tail] * (1 + len(_neg_head_ents)))
        split_idx = len(head_ent_list)

        # tail corrupt
        _pos_tail_ents = g_subj2objs[_head][_rel]
        _neg_tail_ents = ents - _pos_tail_ents
        head_ent_list.extend([_head] * (1 + len(_neg_tail_ents)))
        tail_ent_list.append(_tail)  # positive example
        tail_ent_list.extend(_neg_tail_ents)  # negative examples

        # all triples to be verified for a test sample
        triplet_list = list([_h, _rel, _t]
                            for _h, _t in zip(head_ent_list, tail_ent_list))

        # build dataset
        rep_src_list = [ent_rel2emb[_h][_rel] for _h, _rel, _ in triplet_list]
        rep_tgt_list = [ent2emb[_t] for _, _rel, _t in triplet_list]
        all_rep_src = torch.stack(rep_src_list, dim=0).to(args.device)
        all_rep_tgt = torch.stack(rep_tgt_list, dim=0).to(args.device)

        local_logits_list = []
        local_distances_list = []
        sim_batch_size = args.eval_batch_size * 8
        for _idx_r in range(0, all_rep_src.shape[0], sim_batch_size):
            _rep_src, _rep_tgt = all_rep_src[
                _idx_r:_idx_r + sim_batch_size], all_rep_tgt[_idx_r:_idx_r +
                                                             sim_batch_size]
            with torch.no_grad():
                logits = model.classifier(_rep_src, _rep_tgt)
                distances = model.distance_metric_fn(_rep_src, _rep_tgt)
                logits = logits.detach().cpu().numpy()
                distances = distances.detach().cpu().numpy()
            local_logits_list.append(logits)
            local_distances_list.append(distances)
        sample_logits_list = np.concatenate(local_logits_list, axis=0)
        sample_distances_list = np.concatenate(local_distances_list, axis=0)
        sample_logits_list = torch.from_numpy(sample_logits_list)
        scores = torch.softmax(sample_logits_list, dim=-1)[:, 1]
        scores = scores.detach().cpu().numpy()

        # left
        left_scores = scores[:split_idx]
        left_sort_idxs = np.argsort(-left_scores)
        left_rank = np.where(left_sort_idxs == 0)[0][0] + 1
        left_sort_idxs = left_sort_idxs[:20]
        left_rk_inf = []
        left_rk_inf.append([ent2text_dict[_head], left_rank, left_scores[0]])
        for i, idx in enumerate(left_sort_idxs):
            corrupt_rank = np.where(left_sort_idxs == idx)[0][0] + 1
            corrupt_text = ent2text_dict[head_ent_list[idx]]
            left_rk_inf.append([corrupt_text, corrupt_rank, left_scores[idx]])

        # right
        right_scores = scores[split_idx:]
        right_sort_idxs = np.argsort(-right_scores)
        right_rank = np.where(right_sort_idxs == 0)[0][0] + 1
        right_sort_idxs = right_sort_idxs[:20]
        right_rk_inf = []
        right_rk_inf.append(
            [ent2text_dict[_tail], right_rank, right_scores[0]])
        for i, idx in enumerate(right_sort_idxs):
            corrupt_rank = np.where(right_sort_idxs == idx)[0][0] + 1
            corrupt_text = ent2text_dict[tail_ent_list[split_idx + idx]]
            right_rk_inf.append(
                [corrupt_text, corrupt_rank, right_scores[idx]])

        _head_text = ent2text_dict[_head]
        _tail_text = ent2text_dict[_tail]
        _triplet_text = tuple([_head_text, _rel, _tail_text])
        # add info to results_dict:
        # {triple_text: {"head": [gold info] + top-20 [text, rank, score], "tail": [...]}, ...}
        results_dict[_triplet_text] = {
            "head": left_rk_inf,
            "tail": right_rk_inf
        }
        with open(case_save_path, 'a', encoding='utf-8') as f:
            f.write(str([_head_text, _rel, _tail_text]) + '\n')
            f.write("head:" + str(left_rk_inf) + '\n')
            f.write("tail:" + str(right_rk_inf) + '\n\n')
    logging.info("Get cases in text finished")
    if dir_exists(save_dir):
        torch.save(results_dict, dict_save_path)
    logging.info("Get cases in dict finished")
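
# Usage sketch (illustrative, not original code): pretty-print a few collected
# cases from the dictionary saved by collect_case above.
def show_cases(save_dir, k=3):
    results_dict = torch.load(join(save_dir, 'cases_alone.dict'))
    for (head_text, rel, tail_text), rk in list(results_dict.items())[:k]:
        # rk["head"][0] is [gold head text, its rank, its score];
        # the remaining entries are the top-20 ranked candidates.
        print(head_text, rel, tail_text)
        print("  gold head rank:", rk["head"][0][1],
              "gold tail rank:", rk["tail"][0][1])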
def get_model_dataset(args, model, dataset_list, data_type='test',
                      top_n=1000, verbose=True):
    model.eval()
    if data_type == 'train':
        raw_examples = dataset_list[0].raw_examples
    elif data_type == 'dev':
        raw_examples = dataset_list[1].raw_examples
    elif data_type == 'test':
        raw_examples = dataset_list[2].raw_examples
    # get the last one (i.e., test) to make use of its useful functions and data
    dataset = dataset_list[2]
    ent_list = list(sorted(list(dataset.ent2idx.keys())))
    rel_list = list(sorted(list(dataset.rel2idx.keys())))

    # prepare to remove the true triples
    g_subj2objs = collections.defaultdict(lambda: collections.defaultdict(set))
    g_obj2subjs = collections.defaultdict(lambda: collections.defaultdict(set))
    for _ds in dataset_list:
        for _raw_ex in _ds.raw_examples:
            _head, _rel, _tail = _raw_ex
            g_subj2objs[_head][_rel].add(_tail)
            g_obj2subjs[_tail][_rel].add(_head)

    save_dir = args.model_name_or_path if dir_exists(
        args.model_name_or_path) else args.output_dir
    save_path = os.path.join(save_dir, "saved_emb_mat.np")
    emb_mat = get_emb_mat(save_dir, save_path, ent_list, rel_list, dataset,
                          args, model)
    ent2emb, ent_rel2emb = assign_emb2elements(ent_list, rel_list, emb_mat)

    if data_type == 'train':
        head_scores_path = os.path.join(
            save_dir, data_type + "_head_topN_scores.list")  # "_head_scores.list"
        tail_scores_path = os.path.join(
            save_dir, data_type + "_tail_topN_scores.list")  # "_tail_scores.list"
    elif data_type == 'dev' or data_type == 'test':
        full_head_scores_path = os.path.join(
            save_dir, data_type + "_head_full_scores.list")
        full_tail_scores_path = os.path.join(
            save_dir, data_type + "_tail_full_scores.list")

    # ========== get ranked logits score and corresponding index ==========
    full_head_scores = []
    full_tail_scores = []
    head_scores = []
    tail_scores = []
    split_idx = len(dataset.id2ent_list)
    id2ent_list = dataset.id2ent_list
    head_triple_idx_list = []
    tail_triple_idx_list = []
    for _idx_ex, _triplet in enumerate(
            tqdm(raw_examples, desc="Get_" + data_type + "_datasets")):
        _head, _rel, _tail = _triplet
        head_ent_list = id2ent_list
        tail_ent_list = [_tail] * split_idx
        head_ent_list = head_ent_list + [_head] * split_idx
        tail_ent_list = tail_ent_list + id2ent_list
        triplet_list = list([_h, _rel, _t]
                            for _h, _t in zip(head_ent_list, tail_ent_list))

        # build dataset
        rep_src_list = [ent_rel2emb[_h][_rel] for _h, _rel, _ in triplet_list]
        rep_tgt_list = [ent2emb[_t] for _, _rel, _t in triplet_list]
        all_rep_src = torch.stack(rep_src_list, dim=0).to(args.device)
        all_rep_tgt = torch.stack(rep_tgt_list, dim=0).to(args.device)

        local_logits_list = []
        sim_batch_size = args.eval_batch_size * 8
        for _idx_r in range(0, all_rep_src.shape[0], sim_batch_size):
            _rep_src, _rep_tgt = all_rep_src[
                _idx_r:_idx_r + sim_batch_size], all_rep_tgt[_idx_r:_idx_r +
                                                             sim_batch_size]
            with torch.no_grad():
                logits = model.classifier(_rep_src, _rep_tgt)
                logits = torch.softmax(logits, dim=-1)
                local_scores = logits.detach().cpu().numpy()[:, 1]
            local_logits_list.append(local_scores)
        sample_logits_list = np.concatenate(local_logits_list, axis=0)

        full_head_scores.append(
            [np.array(dataset.ent2idx[_head]), sample_logits_list[:split_idx]])
        full_tail_scores.append(
            [np.array(dataset.ent2idx[_tail]), sample_logits_list[split_idx:]])

        head_logits_list = sample_logits_list[:split_idx]
        tail_logits_list = sample_logits_list[split_idx:]
        pos_head_idx, pos_tail_idx = dataset.ent2idx[_head], dataset.ent2idx[
            _tail]

        head_pos_score = head_logits_list[pos_head_idx]
        head_sort_idxs = np.argsort(-head_logits_list)
        head_pos_rank = np.where(head_sort_idxs == pos_head_idx)[0][0] + 1
        head_top_ranked_score = head_logits_list[head_sort_idxs[:top_n]]
        head_top_ranked_idx = head_sort_idxs[:top_n]

        tail_pos_score = tail_logits_list[pos_tail_idx]
        tail_sort_idxs = np.argsort(-tail_logits_list)
        tail_pos_rank = np.where(tail_sort_idxs == pos_tail_idx)[0][0] + 1
        tail_top_ranked_score = tail_logits_list[tail_sort_idxs[:top_n]]
        tail_top_ranked_idx = tail_sort_idxs[:top_n]

        # For each triple, head/tail info: [pos_rank - 1, top-N scores, top-N entity idx]
        if head_pos_rank <= top_n:
            head_scores.append([
                head_pos_rank - 1, head_top_ranked_score,
                head_top_ranked_idx.astype(np.int32)
            ])
            head_triple_idx_list.append(_idx_ex)
        if tail_pos_rank <= top_n:
            tail_scores.append([
                tail_pos_rank - 1, tail_top_ranked_score,
                tail_top_ranked_idx.astype(np.int32)
            ])
            tail_triple_idx_list.append(_idx_ex)

    if dir_exists(save_dir):
        if data_type == 'train':
            torch.save(head_scores, head_scores_path)
            torch.save(tail_scores, tail_scores_path)
            torch.save(head_triple_idx_list,
                       join(save_dir, data_type + '_head_triple_idx.list'))
            torch.save(tail_triple_idx_list,
                       join(save_dir, data_type + '_tail_triple_idx.list'))
        elif data_type == 'dev' or data_type == 'test':
            torch.save(full_head_scores, full_head_scores_path)
            torch.save(full_tail_scores, full_tail_scores_path)
    print("Get scores and datasets finished")

    emb_list = []
    id2ent = dataset.id2ent
    for _i in range(len(ent_list)):
        emb_list.append(ent2emb[id2ent[_i]])
    emb_list = torch.stack(emb_list, dim=0)
    if dir_exists(save_dir):
        torch.save(emb_list, os.path.join(save_dir, "ent_emb.pkl"))
    logging.info("***** Save all entities embedding finished. *****")
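
# Evaluation sketch (illustrative, not original code): each element of a
# "<split>_head_full_scores.list" / "<split>_tail_full_scores.list" file saved
# above is [gold_entity_idx, scores_over_all_entities], so ranks and MRR
# follow directly.
def mrr_from_full_scores(path):
    full_scores = torch.load(path)
    ranks = []
    for gold_idx, scores in full_scores:
        rank = int(np.where(np.argsort(-scores) == int(gold_idx))[0][0]) + 1
        ranks.append(rank)
    return float(np.mean(1.0 / np.array(ranks)))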
def get_scores(args, raw_examples, dataset, model, data_type=None,
               verbose=True):
    save_dir = args.model_name_or_path if dir_exists(
        args.model_name_or_path) else args.output_dir
    save_path = os.path.join(save_dir, "saved_emb_mat.np")
    head_scores_path = os.path.join(save_dir,
                                    data_type + "_head_full_scores.list")
    tail_scores_path = os.path.join(save_dir,
                                    data_type + "_tail_full_scores.list")
    if file_exists(head_scores_path) and file_exists(tail_scores_path):
        logging.info("Load head and tail mode scores")
        head_scores = torch.load(head_scores_path)
        tail_scores = torch.load(tail_scores_path)
        return head_scores, tail_scores

    model.eval()
    ent_list = list(sorted(list(dataset.ent2idx.keys())))
    rel_list = list(sorted(list(dataset.rel2idx.keys())))

    logging.info("Load all embeddings")
    emb_mat = get_emb_mat(save_dir, save_path, ent_list, rel_list, dataset,
                          args, model)
    ent2emb, ent_rel2emb = assign_emb2elements(ent_list, rel_list, emb_mat)

    # ========== get ranked logits score ==========
    head_scores = []
    tail_scores = []
    split_idx = len(dataset.id2ent_list)
    id2ent_list = dataset.id2ent_list
    for _idx_ex, _triplet in enumerate(
            tqdm(raw_examples, desc="get_" + data_type + "_scores")):
        _head, _rel, _tail = _triplet
        head_ent_list = id2ent_list
        tail_ent_list = [_tail] * split_idx
        head_ent_list = head_ent_list + [_head] * split_idx
        tail_ent_list = tail_ent_list + id2ent_list
        triplet_list = list([_h, _rel, _t]
                            for _h, _t in zip(head_ent_list, tail_ent_list))

        # build dataset
        rep_src_list = [ent_rel2emb[_h][_rel] for _h, _rel, _ in triplet_list]
        rep_tgt_list = [ent2emb[_t] for _, _rel, _t in triplet_list]
        all_rep_src = torch.stack(rep_src_list, dim=0).to(args.device)
        all_rep_tgt = torch.stack(rep_tgt_list, dim=0).to(args.device)

        local_logits_list = []
        sim_batch_size = args.eval_batch_size * 8
        for _idx_r in range(0, all_rep_src.shape[0], sim_batch_size):
            _rep_src, _rep_tgt = all_rep_src[
                _idx_r:_idx_r + sim_batch_size], all_rep_tgt[_idx_r:_idx_r +
                                                             sim_batch_size]
            with torch.no_grad():
                logits = model.classifier(_rep_src, _rep_tgt)
                logits = torch.softmax(logits, dim=-1)
                local_scores = logits.detach().cpu().numpy()[:, 1]
            local_logits_list.append(local_scores)
        sample_logits_list = np.concatenate(local_logits_list, axis=0)
        head_scores.append(
            [np.array(dataset.ent2idx[_head]), sample_logits_list[:split_idx]])
        tail_scores.append(
            [np.array(dataset.ent2idx[_tail]), sample_logits_list[split_idx:]])
    if dir_exists(save_dir):
        torch.save(head_scores, head_scores_path)
        torch.save(tail_scores, tail_scores_path)
    logging.info("Get scores finished")
    return head_scores, tail_scores
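
# Evaluation sketch (illustrative, not original code): hits@k over both the
# head and tail predictions returned by get_scores; as above, each entry is
# [gold_entity_idx, scores_over_all_entities].
def hits_at_k(head_scores, tail_scores, k=10):
    hit, total = 0, 0
    for gold_idx, scores in head_scores + tail_scores:
        rank = int(np.where(np.argsort(-scores) == int(gold_idx))[0][0]) + 1
        hit += int(rank <= k)
        total += 1
    return hit / total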
def get_emb_mat(save_dir, save_path, ent_list, rel_list, dataset, args,
                model=None, verbose=True):
    if dir_exists(save_dir) and file_exists(save_path):
        logging.info("load from file")
        emb_mat = torch.load(save_path)
    else:
        # a model is required to (re)compute the matrix when no cache exists
        assert model is not None
        logging.info("get all ids")
        input_ids_list, mask_ids_list, segment_ids_list = [], [], []
        for _ent in tqdm(ent_list):
            for _idx_r, _rel in enumerate(rel_list):
                head_ids, rel_ids, tail_ids = dataset.convert_raw_example_to_features(
                    [_ent, _rel, _ent], method="4")
                head_ids, rel_ids, tail_ids = head_ids[1:-1], rel_ids[
                    1:-1], tail_ids[1:-1]
                # truncate
                max_ent_len = dataset.max_seq_length - 3 - len(rel_ids)
                head_ids = head_ids[:max_ent_len]
                tail_ids = tail_ids[:max_ent_len]

                src_input_ids = [dataset._cls_id] + head_ids + [
                    dataset._sep_id
                ] + rel_ids + [dataset._sep_id]
                src_mask_ids = [1] * len(src_input_ids)
                src_segment_ids = [0] * (len(head_ids) + 2) + [1] * (len(rel_ids) + 1)

                if _idx_r == 0:
                    tgt_input_ids = [dataset._cls_id
                                     ] + tail_ids + [dataset._sep_id]
                    tgt_mask_ids = [1] * len(tgt_input_ids)
                    tgt_segment_ids = [0] * (len(tail_ids) + 2)
                    input_ids_list.append(tgt_input_ids)
                    mask_ids_list.append(tgt_mask_ids)
                    segment_ids_list.append(tgt_segment_ids)
                input_ids_list.append(src_input_ids)
                mask_ids_list.append(src_mask_ids)
                segment_ids_list.append(src_segment_ids)

        # padding
        max_len = max(len(_e) for _e in input_ids_list)
        assert max_len <= dataset.max_seq_length
        input_ids_list = [
            _e + [dataset._pad_id] * (max_len - len(_e))
            for _e in input_ids_list
        ]
        mask_ids_list = [
            _e + [0] * (max_len - len(_e)) for _e in mask_ids_list
        ]
        segment_ids_list = [
            _e + [0] * (max_len - len(_e)) for _e in segment_ids_list
        ]

        # dataset
        enc_dataset = TensorDataset(
            torch.tensor(input_ids_list, dtype=torch.long),
            torch.tensor(mask_ids_list, dtype=torch.long),
            torch.tensor(segment_ids_list, dtype=torch.long),
        )
        enc_dataloader = DataLoader(enc_dataset,
                                    sampler=SequentialSampler(enc_dataset),
                                    batch_size=args.eval_batch_size * 2)

        logging.info("get all emb via model")
        embs_list = []
        for batch in tqdm(enc_dataloader,
                          desc="entity embedding",
                          disable=(not verbose)):
            batch = tuple(t.to(args.device) for t in batch)
            _input_ids, _mask_ids, _segment_ids = batch
            with torch.no_grad():
                embs = model.encoder(_input_ids,
                                     attention_mask=_mask_ids,
                                     token_type_ids=_segment_ids)
                embs = embs.detach().cpu()
            embs_list.append(embs)
        emb_mat = torch.cat(embs_list, dim=0).contiguous()
        assert emb_mat.shape[0] == len(input_ids_list)
        # save emb_mat
        if dir_exists(save_dir):
            torch.save(emb_mat, save_path)
    return emb_mat
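
# Sketch of the row layout consumed by assign_emb2elements (the real helper is
# defined elsewhere in this repo; this minimal version only assumes the order
# produced by get_emb_mat above: per entity, one tail/tgt row followed by one
# (entity, relation) src row per relation).
def assign_emb2elements_sketch(ent_list, rel_list, emb_mat):
    assert len(ent_list) * (1 + len(rel_list)) == emb_mat.shape[0]
    ent2emb = dict()  # entity -> tgt embedding
    ent_rel2emb = collections.defaultdict(dict)  # entity -> rel -> src embedding
    ptr_row = 0
    for _ent in ent_list:
        ent2emb[_ent] = emb_mat[ptr_row]
        ptr_row += 1
        for _rel in rel_list:
            ent_rel2emb[_ent][_rel] = emb_mat[ptr_row]
            ptr_row += 1
    return ent2emb, ent_rel2emb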
def predict_NELL(args, raw_examples, dataset_list, model, verbose=True):
    logging.info("***** Running Prediction *****")
    model.eval()
    # get the last one (i.e., test) to make use of its useful functions and data
    standard_dataset = dataset_list[-1]

    ents = set()
    g_subj2objs = collections.defaultdict(lambda: collections.defaultdict(set))
    g_obj2subjs = collections.defaultdict(lambda: collections.defaultdict(set))
    for _ds in dataset_list:
        for _raw_ex in _ds.raw_examples:
            _head, _rel, _tail = _raw_ex
            ents.add(_head)
            ents.add(_tail)
            g_subj2objs[_head][_rel].add(_tail)
            g_obj2subjs[_tail][_rel].add(_head)
    ent_list = list(sorted(ents))
    # rel_list = list(sorted(standard_dataset.rel_list))
    rel_set = set()
    for triplet in raw_examples:
        _, r, _ = triplet
        rel_set.add(r)
    rel_list = list(sorted(list(rel_set)))

    # ========= get all embeddings ==========
    print("get all embeddings")
    save_dir = args.model_name_or_path if dir_exists(
        args.model_name_or_path) else args.output_dir
    save_path = os.path.join(save_dir, "saved_emb_mat.np")
    if dir_exists(save_dir) and file_exists(save_path):
        print("\tload from file")
        emb_mat = torch.load(save_path)
    else:
        print("\tget all ids")
        input_ids_list, mask_ids_list, segment_ids_list = [], [], []
        for _ent in tqdm(ent_list):
            for _idx_r, _rel in enumerate(rel_list):
                head_ids, rel_ids, tail_ids = standard_dataset.convert_raw_example_to_features(
                    [_ent, _rel, _ent], method="4")
                head_ids, rel_ids, tail_ids = head_ids[1:-1], rel_ids[
                    1:-1], tail_ids[1:-1]
                # truncate
                max_ent_len = standard_dataset.max_seq_length - 3 - len(rel_ids)
                head_ids = head_ids[:max_ent_len]
                tail_ids = tail_ids[:max_ent_len]

                src_input_ids = [standard_dataset._cls_id] + head_ids + [
                    standard_dataset._sep_id
                ] + rel_ids + [standard_dataset._sep_id]
                src_mask_ids = [1] * len(src_input_ids)
                src_segment_ids = [0] * (len(head_ids) + 2) + [1] * (len(rel_ids) + 1)

                if _idx_r == 0:
                    tgt_input_ids = [standard_dataset._cls_id
                                     ] + tail_ids + [standard_dataset._sep_id]
                    tgt_mask_ids = [1] * len(tgt_input_ids)
                    tgt_segment_ids = [0] * (len(tail_ids) + 2)
                    input_ids_list.append(tgt_input_ids)
                    mask_ids_list.append(tgt_mask_ids)
                    segment_ids_list.append(tgt_segment_ids)
                input_ids_list.append(src_input_ids)
                mask_ids_list.append(src_mask_ids)
                segment_ids_list.append(src_segment_ids)

        # padding
        max_len = max(len(_e) for _e in input_ids_list)
        assert max_len <= standard_dataset.max_seq_length
        input_ids_list = [
            _e + [standard_dataset._pad_id] * (max_len - len(_e))
            for _e in input_ids_list
        ]
        mask_ids_list = [_e + [0] * (max_len - len(_e)) for _e in mask_ids_list]
        segment_ids_list = [
            _e + [0] * (max_len - len(_e)) for _e in segment_ids_list
        ]

        # dataset
        enc_dataset = TensorDataset(
            torch.tensor(input_ids_list, dtype=torch.long),
            torch.tensor(mask_ids_list, dtype=torch.long),
            torch.tensor(segment_ids_list, dtype=torch.long),
        )
        enc_dataloader = DataLoader(enc_dataset,
                                    sampler=SequentialSampler(enc_dataset),
                                    batch_size=args.eval_batch_size * 2)

        print("\tget all emb via model")
        embs_list = []
        for batch in tqdm(enc_dataloader,
                          desc="entity embedding",
                          disable=(not verbose)):
            batch = tuple(t.to(args.device) for t in batch)
            _input_ids, _mask_ids, _segment_ids = batch
            with torch.no_grad():
                embs = model.encoder(_input_ids,
                                     attention_mask=_mask_ids,
                                     token_type_ids=_segment_ids)
                embs = embs.detach().cpu()
            embs_list.append(embs)
        emb_mat = torch.cat(embs_list, dim=0).contiguous()
        assert emb_mat.shape[0] == len(input_ids_list)
        # save emb_mat
        if dir_exists(save_dir):
            torch.save(emb_mat, save_path)

    # assign embeddings to entities
    assert len(ent_list) * (1 + len(rel_list)) == emb_mat.shape[0]
    ent_rel2emb = collections.defaultdict(dict)
    ent2emb = dict()
    ptr_row = 0
    for _ent in ent_list:
        for _idx_r, _rel in enumerate(rel_list):
            if _idx_r == 0:
                ent2emb[_ent] = emb_mat[ptr_row]
                ptr_row += 1
            ent_rel2emb[_ent][_rel] = emb_mat[ptr_row]
            ptr_row += 1

    # ========= run link prediction ==========
    # * begin to get hit
    # NELL evaluation only corrupts tails, so ranks_left/hits_left stay empty.
    ranks_left, ranks_right, ranks = [], [], []
    hits_left, hits_right, hits = [], [], []
    top_ten_hit_count = 0
    top_five_hit_count = 0
    top_one_hit_count = 0
    for i in range(10):
        hits_left.append([])
        hits_right.append([])
        hits.append([])

    for _idx_ex, _triplet in enumerate(tqdm(raw_examples, desc="evaluating")):
        _head, _rel, _tail = _triplet
        head_ent_list = []
        tail_ent_list = []

        # tail corrupt (restricted to entities with a type compatible with _rel)
        _pos_tail_ents = g_subj2objs[_head][_rel]
        _neg_tail_ents = ents - _pos_tail_ents
        _neg_tail_ents = [
            _ent for _ent in _neg_tail_ents
            if _ent in standard_dataset.type_dict[_rel]["tail"]
        ]
        head_ent_list.extend([_head] * (1 + len(_neg_tail_ents)))
        tail_ent_list.append(_tail)  # positive example
        tail_ent_list.extend(_neg_tail_ents)  # negative examples

        triplet_list = list([_h, _rel, _t]
                            for _h, _t in zip(head_ent_list, tail_ent_list))

        # build dataset
        rep_src_list = [ent_rel2emb[_h][_rel] for _h, _rel, _ in triplet_list]
        rep_tgt_list = [ent2emb[_t] for _, _rel, _t in triplet_list]
        all_rep_src = torch.stack(rep_src_list, dim=0).to(args.device)
        all_rep_tgt = torch.stack(rep_tgt_list, dim=0).to(args.device)

        local_scores_list = []
        sim_batch_size = args.eval_batch_size * 8
        if args.cls_method == "dis":
            for _idx_r in range(0, all_rep_src.shape[0], sim_batch_size):
                _rep_src, _rep_tgt = all_rep_src[
                    _idx_r:_idx_r + sim_batch_size], all_rep_tgt[
                        _idx_r:_idx_r + sim_batch_size]
                with torch.no_grad():
                    distances = model.distance_metric_fn(
                        _rep_src, _rep_tgt).to(torch.float32)
                    local_scores = -distances
                    local_scores = local_scores.detach().cpu().numpy()
                local_scores_list.append(local_scores)
        elif args.cls_method == "cls":
            for _idx_r in range(0, all_rep_src.shape[0], sim_batch_size):
                _rep_src, _rep_tgt = all_rep_src[
                    _idx_r:_idx_r + sim_batch_size], all_rep_tgt[
                        _idx_r:_idx_r + sim_batch_size]
                with torch.no_grad():
                    logits = model.classifier(_rep_src,
                                              _rep_tgt).to(torch.float32)
                    logits = torch.softmax(logits, dim=-1)
                    local_scores = logits.detach().cpu().numpy()[:, 1]
                local_scores_list.append(local_scores)
        scores = np.concatenate(local_scores_list, axis=0)

        # right
        right_scores = scores
        right_rank = safe_ranking(right_scores)
        ranks_right.append(right_rank)
        ranks.append(right_rank)

        # log
        top_ten_hit_count += int(right_rank <= 10)
        top_five_hit_count += int(right_rank <= 5)
        top_one_hit_count += int(right_rank <= 1)
        if (_idx_ex + 1) % 10 == 0:
            logger.info("hit@1 until now: {}".format(top_one_hit_count * 1.0 / len(ranks)))
            logger.info("hit@5 until now: {}".format(top_five_hit_count * 1.0 / len(ranks)))
            logger.info("hit@10 until now: {}".format(top_ten_hit_count * 1.0 / len(ranks)))

        # hits
        for hits_level in range(10):
            if right_rank <= hits_level + 1:
                hits[hits_level].append(1.0)
                hits_right[hits_level].append(1.0)
            else:
                hits[hits_level].append(0.0)
                hits_right[hits_level].append(0.0)

    if verbose:
        for i in [0, 4, 9]:
            logger.info('Hits right @{0}: {1}'.format(i + 1, np.mean(hits_right[i])))
        logger.info('Mean rank right: {0}'.format(np.mean(ranks_right)))
        logger.info('Mean reciprocal rank right: {0}'.format(
            np.mean(1. / np.array(ranks_right))))

    with open(join(args.output_dir, "link_prediction_metrics.txt"),
              "w", encoding="utf-8") as fp:
        for i in [0, 4, 9]:
            fp.write('Hits right @{0}: {1}\n'.format(i + 1, np.mean(hits_right[i])))
        fp.write('Mean rank right: {0}\n'.format(np.mean(ranks_right)))
        fp.write('Mean reciprocal rank right: {0}\n'.format(
            np.mean(1. / np.array(ranks_right))))
    print("save finished!")

    # NOTE: since ranks_left is never filled (only tail corruption is run for
    # NELL), this zip yields an empty list; callers should rely on ranks_right
    # or the metrics file instead.
    tuple_ranks = [[int(_l), int(_r)] for _l, _r in zip(ranks_left, ranks_right)]
    return tuple_ranks
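
# safe_ranking is imported from elsewhere in this repo. A plausible minimal
# version is sketched below (an assumption, not the repo's exact code): the
# gold triple's score sits at index 0, and ties with other candidates are
# broken randomly so that equal scores do not systematically favor (or
# penalize) the gold entity.
def safe_ranking_sketch(scores):
    tied = np.where(scores == scores[0])[0]  # candidates tied with the gold score
    gold_pos = np.random.choice(tied)        # random tie-breaking
    sort_idxs = np.argsort(-scores)
    return int(np.where(sort_idxs == gold_pos)[0][0]) + 1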