def get_dev_candidate(args, raw_examples, dataset_list, model, verbose=True):
    """Score every corrupted variant of each example and save the picked candidates.

    For each ``(head, rel, tail)`` in ``raw_examples``:

    * the head is replaced by every entity that does not form a known
      ``(?, rel, tail)`` triple in any dataset of ``dataset_list``;
    * the tail is replaced by every entity that does not form a known
      ``(head, rel, ?)`` triple.

    All resulting triples (with the true triple first in each half) are scored
    with ``model.classifier``. ``get_triplet_candidate`` then picks indices in
    each half. The picked entities are stored under ``"heads_corrupt"`` and
    ``"tails_corrupt"`` in a dict keyed by ``tuple(triplet)``. That dict is saved
    with ``torch.save`` to ``./data/<args.dataset>/new_dev.dict``.

    Args:
        args: run configuration. Reads ``model_name_or_path`` (embedding cache
            directory), ``dataset``, ``eval_batch_size`` and ``device``.
        raw_examples: sequence of ``(head, rel, tail)`` triples to process. It is
            iterated twice, so it must not be a one-shot iterator.
        dataset_list: datasets whose ``raw_examples`` define the known-triple
            filter and the entity vocabulary. The last one also supplies the
            feature-conversion helpers.
        model: model exposing ``eval()``, ``encoder`` and ``classifier``.
        verbose: whether to show the embedding progress bars.

    Returns:
        The candidate dict that was saved.
    """
    logging.info("***** Running Prediction*****")
    model.eval()
    # Use the last dataset (i.e. test) for its feature-conversion helpers and data.
    standard_dataset = dataset_list[-1]

    # Known triples from every split; used to filter false negatives out of the corruptions.
    ents = set()
    g_subj2objs = collections.defaultdict(lambda: collections.defaultdict(set))
    g_obj2subjs = collections.defaultdict(lambda: collections.defaultdict(set))
    for _ds in dataset_list:
        for _head, _rel, _tail in _ds.raw_examples:
            ents.add(_head)
            ents.add(_tail)
            g_subj2objs[_head][_rel].add(_tail)
            g_obj2subjs[_tail][_rel].add(_head)
    # Sorted, so corruption order (and hence score tie-breaking) is reproducible across runs.
    ent_list = sorted(ents)
    # Only relations that occur in raw_examples need (entity, relation) embeddings.
    rel_list = sorted({_rel for _, _rel, _ in raw_examples})

    new_dev_path = os.path.join("./data", args.dataset, "new_dev.dict")
    new_dev_dict = collections.defaultdict(dict)
    if not rel_list:
        # No examples: there is nothing to embed or score.
        torch.save(new_dev_dict, new_dev_path)
        return new_dev_dict

    # ========= get all embeddings ==========
    print("get all embeddings")
    save_dir = args.model_name_or_path
    save_path = os.path.join(save_dir, "saved_dev_emb_mat.np")
    emb_mat = get_emb_mat(save_dir, save_path, ent_list, rel_list,
                          standard_dataset, args, model, verbose=verbose)

    # Per entity the rows are: entity alone (tail side), then one row per relation
    # (entity as head + relation), in ent_list / rel_list order.
    assert len(ent_list) * (1 + len(rel_list)) == emb_mat.shape[0]
    ent_rel2emb = collections.defaultdict(dict)  # ent + r (h + r)
    ent2emb = dict()  # ent (t)
    ptr_row = 0
    for _ent in ent_list:
        ent2emb[_ent] = emb_mat[ptr_row]
        ptr_row += 1
        for _rel in rel_list:
            ent_rel2emb[_ent][_rel] = emb_mat[ptr_row]
            ptr_row += 1

    # ========== get candidates ==========
    sim_batch_size = args.eval_batch_size * 8
    for _triplet in tqdm(raw_examples, desc="get candidates"):
        _head, _rel, _tail = _triplet

        # Head corruption: the true head first, then every entity not known to form (?, rel, tail).
        _pos_head_ents = g_obj2subjs[_tail][_rel]
        _neg_head_ents = [_e for _e in ent_list if _e not in _pos_head_ents]
        head_ent_list = [_head] + _neg_head_ents
        tail_ent_list = [_tail] * len(head_ent_list)
        split_idx = len(head_ent_list)

        # Tail corruption: the true tail first, then every entity not known to form (head, rel, ?).
        _pos_tail_ents = g_subj2objs[_head][_rel]
        _neg_tail_ents = [_e for _e in ent_list if _e not in _pos_tail_ents]
        head_ent_list += [_head] * (1 + len(_neg_tail_ents))
        tail_ent_list += [_tail] + _neg_tail_ents

        all_rep_src = torch.stack([ent_rel2emb[_h][_rel] for _h in head_ent_list], dim=0).to(args.device)
        all_rep_tgt = torch.stack([ent2emb[_t] for _t in tail_ent_list], dim=0).to(args.device)

        local_scores_list = []
        for _start in range(0, all_rep_src.shape[0], sim_batch_size):
            _end = _start + sim_batch_size
            with torch.no_grad():
                # Softmax in float32 (as predict_NELL does); half-precision
                # probabilities near 1 collapse into ties.
                logits = model.classifier(all_rep_src[_start:_end],
                                          all_rep_tgt[_start:_end]).to(torch.float32)
                probs = torch.softmax(logits, dim=-1)
            local_scores_list.append(probs.detach().cpu().numpy()[:, 1])
        scores = np.concatenate(local_scores_list, axis=0)

        heads_corrupt_idx = get_triplet_candidate(scores[:split_idx])
        tails_corrupt_idx = get_triplet_candidate(scores[split_idx:])
        new_dev_dict[tuple(_triplet)]["heads_corrupt"] = [head_ent_list[i] for i in heads_corrupt_idx]
        new_dev_dict[tuple(_triplet)]["tails_corrupt"] = [tail_ent_list[i + split_idx] for i in tails_corrupt_idx]

    torch.save(new_dev_dict, new_dev_path)
    return new_dev_dict
def get_emb_mat(save_dir, save_path, ent_list, rel_list, dataset, args, model=None, verbose=True):
    """Return encoder embeddings for every entity and every (entity, relation) pair.

    The matrix has ``len(ent_list) * (1 + len(rel_list))`` rows, grouped per
    entity in ``ent_list`` order:

    * first, the entity encoded alone as ``[cls] ent [sep]``;
    * then one row per relation in ``rel_list`` order, encoded as
      ``[cls] ent [sep] rel [sep]``.

    If ``save_dir`` exists and ``save_path`` is a file, the matrix is loaded from
    it with ``torch.load``. Otherwise it is computed with ``model.encoder`` and,
    when ``save_dir`` exists, saved to ``save_path``.

    Args:
        save_dir: directory whose existence enables the cache.
        save_path: cache file path.
        ent_list: entities, in the row order wanted.
        rel_list: relations, in the row order wanted.
        dataset: provides ``convert_raw_example_to_features``, ``max_seq_length``
            and the ``_cls_id`` / ``_sep_id`` / ``_pad_id`` token ids.
        args: reads ``eval_batch_size`` and ``device``.
        model: model exposing ``encoder``. Only needed on a cache miss.
        verbose: whether to show progress bars.

    Raises:
        ValueError: on a cache miss when ``model`` is None, or when ``ent_list``
            or ``rel_list`` is empty.
    """
    if dir_exists(save_dir) and file_exists(save_path):
        logging.info("load from file")
        return torch.load(save_path)
    if model is None:
        raise ValueError(
            "get_emb_mat: no cached embeddings at {} and no model given to compute them".format(save_path))
    if not ent_list or not rel_list:
        raise ValueError("get_emb_mat: ent_list and rel_list must both be non-empty")

    logging.info("get all ids")
    input_ids_list, mask_ids_list, segment_ids_list = [], [], []
    for _ent in tqdm(ent_list, disable=(not verbose)):
        for _idx_r, _rel in enumerate(rel_list):
            head_ids, rel_ids, tail_ids = dataset.convert_raw_example_to_features(
                [_ent, _rel, _ent], method="4")
            # Strip the first/last id of each part; cls/sep ids are re-added explicitly below.
            head_ids, rel_ids, tail_ids = head_ids[1:-1], rel_ids[1:-1], tail_ids[1:-1]
            # Truncate entity ids to leave room for cls + 2 * sep + the relation ids.
            max_ent_len = dataset.max_seq_length - 3 - len(rel_ids)
            head_ids = head_ids[:max_ent_len]
            tail_ids = tail_ids[:max_ent_len]

            if _idx_r == 0:
                # One entity-only (tail-side) row per entity, placed before its relation rows.
                tgt_input_ids = [dataset._cls_id] + tail_ids + [dataset._sep_id]
                input_ids_list.append(tgt_input_ids)
                mask_ids_list.append([1] * len(tgt_input_ids))
                segment_ids_list.append([0] * (len(tail_ids) + 2))

            src_input_ids = [dataset._cls_id] + head_ids + [dataset._sep_id] + rel_ids + [dataset._sep_id]
            input_ids_list.append(src_input_ids)
            mask_ids_list.append([1] * len(src_input_ids))
            segment_ids_list.append([0] * (len(head_ids) + 2) + [1] * (len(rel_ids) + 1))

    # Pad every sequence to the longest one.
    max_len = max(len(_e) for _e in input_ids_list)
    assert max_len <= dataset.max_seq_length
    input_ids_list = [_e + [dataset._pad_id] * (max_len - len(_e)) for _e in input_ids_list]
    mask_ids_list = [_e + [0] * (max_len - len(_e)) for _e in mask_ids_list]
    segment_ids_list = [_e + [0] * (max_len - len(_e)) for _e in segment_ids_list]

    enc_dataset = TensorDataset(
        torch.tensor(input_ids_list, dtype=torch.long),
        torch.tensor(mask_ids_list, dtype=torch.long),
        torch.tensor(segment_ids_list, dtype=torch.long),
    )
    # Sequential sampling keeps embedding rows aligned with input_ids_list.
    enc_dataloader = DataLoader(enc_dataset, sampler=SequentialSampler(enc_dataset),
                                batch_size=args.eval_batch_size * 2)

    logging.info("get all emb via model")
    embs_list = []
    for batch in tqdm(enc_dataloader, desc="entity embedding", disable=(not verbose)):
        _input_ids, _mask_ids, _segment_ids = (t.to(args.device) for t in batch)
        with torch.no_grad():
            embs = model.encoder(_input_ids, attention_mask=_mask_ids, token_type_ids=_segment_ids)
        embs_list.append(embs.detach().cpu())
    emb_mat = torch.cat(embs_list, dim=0).contiguous()
    assert emb_mat.shape[0] == len(input_ids_list)

    if dir_exists(save_dir):
        torch.save(emb_mat, save_path)
    return emb_mat
def get_scores(args, raw_examples, dataset, model, data_type=None, verbose=True):
    """Score every head and tail replacement of each example against all entities.

    For each ``(head, rel, tail)`` in ``raw_examples``, two sets of triples are
    scored with ``model.classifier``:

    * ``(e, rel, tail)`` for every entity ``e`` in ``dataset.id2ent_list``;
    * ``(head, rel, e)`` for every entity ``e`` in ``dataset.id2ent_list``.

    Each score is the softmax probability of class 1. Results are cached in
    ``save_dir`` as ``<data_type>_head_full_scores.list`` and
    ``<data_type>_tail_full_scores.list``. If both files exist, they are loaded
    and returned without running the model.

    Args:
        args: reads ``model_name_or_path``, ``output_dir``, ``eval_batch_size``
            and ``device``.
        raw_examples: iterable of ``(head, rel, tail)`` triples.
        dataset: provides ``ent2idx``, ``rel2idx``, ``id2ent_list`` and the
            helpers needed by ``get_emb_mat``.
        model: model exposing ``eval()``, ``encoder`` and ``classifier``.
        data_type: tag used in the cache file names and the progress bar.
            ``None`` is treated as ``"eval"``.
        verbose: whether to show the scoring progress bar.

    Returns:
        ``(head_scores, tail_scores)``. Each is a list with one
        ``[np.array(gold_entity_index), scores]`` pair per example, where
        ``scores`` follows ``dataset.id2ent_list`` order.
    """
    if data_type is None:
        # String concatenation below cannot take None; fall back to a neutral tag.
        data_type = "eval"
    save_dir = args.model_name_or_path if dir_exists(args.model_name_or_path) else args.output_dir
    save_path = os.path.join(save_dir, "saved_emb_mat.np")
    head_scores_path = os.path.join(save_dir, data_type + "_head_full_scores.list")
    tail_scores_path = os.path.join(save_dir, data_type + "_tail_full_scores.list")
    if file_exists(head_scores_path) and file_exists(tail_scores_path):
        logging.info("Load head and tail mode scores")
        head_scores = torch.load(head_scores_path)
        tail_scores = torch.load(tail_scores_path)
        return head_scores, tail_scores

    model.eval()
    ent_list = sorted(dataset.ent2idx.keys())
    rel_list = sorted(dataset.rel2idx.keys())
    logging.info("Load all embeddings")
    emb_mat = get_emb_mat(save_dir, save_path, ent_list, rel_list, dataset, args, model)
    ent2emb, ent_rel2emb = assign_emb2elements(ent_list, rel_list, emb_mat)

    # ========== get ranked logits score ==========
    id2ent_list = dataset.id2ent_list
    split_idx = len(id2ent_list)
    sim_batch_size = args.eval_batch_size * 8
    # Targets for the tail-replacement half (every entity, in id order) are the
    # same for every example, so build them once.
    all_ent_tgt = torch.stack([ent2emb[_e] for _e in id2ent_list], dim=0).to(args.device)

    head_scores, tail_scores = [], []
    for _head, _rel, _tail in tqdm(raw_examples, desc="get_" + data_type + "_scores",
                                   disable=(not verbose)):
        # First split_idx rows: (e, rel, tail) for every entity e.
        # Remaining split_idx rows: (head, rel, e) for every entity e.
        head_src = torch.stack([ent_rel2emb[_e][_rel] for _e in id2ent_list], dim=0).to(args.device)
        fixed_src = ent_rel2emb[_head][_rel].to(args.device)
        fixed_tgt = ent2emb[_tail].to(args.device)
        all_rep_src = torch.cat(
            [head_src, fixed_src.unsqueeze(0).expand(split_idx, *fixed_src.shape)], dim=0)
        all_rep_tgt = torch.cat(
            [fixed_tgt.unsqueeze(0).expand(split_idx, *fixed_tgt.shape), all_ent_tgt], dim=0)

        local_scores_list = []
        for _start in range(0, all_rep_src.shape[0], sim_batch_size):
            _end = _start + sim_batch_size
            with torch.no_grad():
                # Softmax in float32 so half-precision models do not produce spurious ties.
                logits = model.classifier(all_rep_src[_start:_end],
                                          all_rep_tgt[_start:_end]).to(torch.float32)
                probs = torch.softmax(logits, dim=-1)
            local_scores_list.append(probs.detach().cpu().numpy()[:, 1])
        sample_scores = np.concatenate(local_scores_list, axis=0)

        head_scores.append([np.array(dataset.ent2idx[_head]), sample_scores[:split_idx]])
        tail_scores.append([np.array(dataset.ent2idx[_tail]), sample_scores[split_idx:]])

    if dir_exists(save_dir):
        torch.save(head_scores, head_scores_path)
        torch.save(tail_scores, tail_scores_path)
    logging.info("Get scores finished")
    return head_scores, tail_scores
def predict_NELL(args, raw_examples, dataset_list, model, verbose=True):
    """Evaluate filtered, type-constrained tail prediction and report hits/MR/MRR.

    For each ``(head, rel, tail)`` in ``raw_examples``, the true tail is scored
    first. It is followed by every entity that meets both conditions:

    (a) it does not form a known ``(head, rel, ?)`` triple in any dataset of
        ``dataset_list``;
    (b) it appears in ``dataset_list[-1].type_dict[rel]["tail"]``.

    ``safe_ranking`` turns the scores into a rank for the true tail. Scores are
    computed in one of two ways:

    * ``args.cls_method == "cls"``: softmax probability of class 1 from
      ``model.classifier``;
    * ``args.cls_method == "dis"``: negated ``model.distance_metric_fn``.

    Running hits@1/5/10 are logged every 10 examples. When ``verbose`` is set,
    the final hits@1/5/10, mean rank and MRR are logged and written to
    ``<args.output_dir>/link_prediction_metrics.txt``.

    Args:
        args: reads ``model_name_or_path``, ``output_dir``, ``eval_batch_size``,
            ``device`` and ``cls_method``.
        raw_examples: sequence of ``(head, rel, tail)`` triples. It is iterated
            twice.
        dataset_list: datasets defining the known-triple filter and the entity
            vocabulary. The last one supplies feature helpers and ``type_dict``.
        model: model exposing ``eval()``, ``encoder`` and ``classifier`` or
            ``distance_metric_fn``.
        verbose: whether to show progress and write the final metrics.

    Returns:
        ``[[left_rank, right_rank], ...]``. Only tail corruption is evaluated,
        so no left ranks are produced and this list is always empty.

    Raises:
        ValueError: if ``args.cls_method`` is neither ``"dis"`` nor ``"cls"``.
    """
    if args.cls_method not in ("dis", "cls"):
        # Otherwise no scores would be produced and np.concatenate would fail after
        # all embeddings had already been computed.
        raise ValueError("predict_NELL: unsupported cls_method {!r}; expected 'dis' or 'cls'".format(
            args.cls_method))

    logging.info("***** Running Prediction*****")
    model.eval()
    # Use the last dataset (i.e. test) for its feature-conversion helpers and data.
    standard_dataset = dataset_list[-1]

    # Known triples from every split; used to filter false negatives.
    ents = set()
    g_subj2objs = collections.defaultdict(lambda: collections.defaultdict(set))
    g_obj2subjs = collections.defaultdict(lambda: collections.defaultdict(set))
    for _ds in dataset_list:
        for _head, _rel, _tail in _ds.raw_examples:
            ents.add(_head)
            ents.add(_tail)
            g_subj2objs[_head][_rel].add(_tail)
            g_obj2subjs[_tail][_rel].add(_head)
    # Sorted, so candidate order (and hence tie handling in ranking) is reproducible across runs.
    ent_list = sorted(ents)
    rel_list = sorted({_rel for _, _rel, _ in raw_examples})

    # ========= get all embeddings ==========
    print("get all embeddings")
    save_dir = args.model_name_or_path if dir_exists(args.model_name_or_path) else args.output_dir
    save_path = os.path.join(save_dir, "saved_emb_mat.np")
    emb_mat = get_emb_mat(save_dir, save_path, ent_list, rel_list,
                          standard_dataset, args, model, verbose=verbose)

    # Per entity the rows are: entity alone (tail side), then one row per relation.
    assert len(ent_list) * (1 + len(rel_list)) == emb_mat.shape[0]
    ent_rel2emb = collections.defaultdict(dict)
    ent2emb = dict()
    ptr_row = 0
    for _ent in ent_list:
        ent2emb[_ent] = emb_mat[ptr_row]
        ptr_row += 1
        for _rel in rel_list:
            ent_rel2emb[_ent][_rel] = emb_mat[ptr_row]
            ptr_row += 1

    # ========= run link prediction ==========
    # NOTE(review): only tail corruption is evaluated, so ranks_left stays empty.
    ranks_left, ranks_right, ranks = [], [], []
    hits_right = [[] for _ in range(10)]
    hits = [[] for _ in range(10)]
    top_ten_hit_count = 0
    top_five_hit_count = 0
    top_one_hit_count = 0
    sim_batch_size = args.eval_batch_size * 8
    # Per-relation set of allowed tails. This assumes type_dict[rel]["tail"] is a
    # collection of entity ids (list/set/dict); set() keeps its membership
    # semantics but makes each lookup O(1).
    allowed_tails_by_rel = {}

    for _idx_ex, _triplet in enumerate(tqdm(raw_examples, desc="evaluating")):
        _head, _rel, _tail = _triplet
        if _rel not in allowed_tails_by_rel:
            allowed_tails_by_rel[_rel] = set(standard_dataset.type_dict[_rel]["tail"])
        _allowed_tails = allowed_tails_by_rel[_rel]

        # Tail corruption: the true tail first, then filtered, type-compatible negatives.
        _pos_tail_ents = g_subj2objs[_head][_rel]
        _neg_tail_ents = [_e for _e in ent_list
                          if _e not in _pos_tail_ents and _e in _allowed_tails]
        tail_ent_list = [_tail] + _neg_tail_ents

        all_rep_src = torch.stack([ent_rel2emb[_head][_rel]] * len(tail_ent_list), dim=0).to(args.device)
        all_rep_tgt = torch.stack([ent2emb[_t] for _t in tail_ent_list], dim=0).to(args.device)

        local_scores_list = []
        for _start in range(0, all_rep_src.shape[0], sim_batch_size):
            _end = _start + sim_batch_size
            _rep_src, _rep_tgt = all_rep_src[_start:_end], all_rep_tgt[_start:_end]
            with torch.no_grad():
                if args.cls_method == "dis":
                    distances = model.distance_metric_fn(_rep_src, _rep_tgt).to(torch.float32)
                    local_scores = (-distances).detach().cpu().numpy()
                else:  # "cls"
                    logits = model.classifier(_rep_src, _rep_tgt).to(torch.float32)
                    local_scores = torch.softmax(logits, dim=-1).detach().cpu().numpy()[:, 1]
            local_scores_list.append(local_scores)
        right_scores = np.concatenate(local_scores_list, axis=0)

        # Presumably the 1-based rank of the true tail (index 0); confirm against safe_ranking.
        right_rank = safe_ranking(right_scores)
        ranks_right.append(right_rank)
        ranks.append(right_rank)

        top_ten_hit_count += int(right_rank <= 10)
        top_five_hit_count += int(right_rank <= 5)
        top_one_hit_count += int(right_rank <= 1)
        if (_idx_ex + 1) % 10 == 0:
            logger.info("hit@1 until now: {}".format(top_one_hit_count * 1.0 / len(ranks)))
            logger.info("hit@5 until now: {}".format(top_five_hit_count * 1.0 / len(ranks)))
            logger.info("hit@10 until now: {}".format(top_ten_hit_count * 1.0 / len(ranks)))

        for hits_level in range(10):
            _hit = 1.0 if right_rank <= hits_level + 1 else 0.0
            hits[hits_level].append(_hit)
            hits_right[hits_level].append(_hit)

    if verbose:
        for i in [0, 4, 9]:
            logger.info('Hits right @{0}: {1}'.format(i + 1, np.mean(hits_right[i])))
        logger.info('Mean rank right: {0}'.format(np.mean(ranks_right)))
        logger.info('Mean reciprocal rank right: {0}'.format(np.mean(1. / np.array(ranks_right))))
        with open(join(args.output_dir, "link_prediction_metrics.txt"), "w", encoding="utf-8") as fp:
            for i in [0, 4, 9]:
                fp.write('Hits right @{0}: {1}\n'.format(i + 1, np.mean(hits_right[i])))
            fp.write('Mean rank right: {0}\n'.format(np.mean(ranks_right)))
            fp.write('Mean reciprocal rank right: {0}\n'.format(np.mean(1. / np.array(ranks_right))))
        print("save finished!")

    # Always [] here because ranks_left is never filled (see NOTE above).
    tuple_ranks = [[int(_l), int(_r)] for _l, _r in zip(ranks_left, ranks_right)]
    return tuple_ranks