def main():
    """Nudge stage-1 tail predictions with cached neighbour scores and evaluate.

    Each line of the stage-1 prediction file is tab-separated: the first field
    is the head-entity mention and the last field is a Python literal holding
    ``[entity_id, score]`` pairs.  For every prediction whose entity id also
    appears in the head entity's cached neighbour list, the neighbour's score
    is added to the prediction's score.  The adjusted lists are passed to
    ``utils.get_metrics_using_topk`` as both the tail and the head answers,
    and the returned metrics are printed.
    """
    data_dir = "data/olpbench"
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "entity_id_map.txt"))
    relation_mentions, rm_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "relation_id_map.txt"))

    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    test_kb = kb(os.path.join(data_dir, "test_data.txt"), em_map=em_map, rm_map=rm_map)

    # Context managers guarantee both handles are closed even if parsing fails.
    with open("helper_scripts/tmp/test_data_preds.txt.tail_thorough_f5_d300_e50.stage1", 'r') as preds_file:
        xt_lines = preds_file.readlines()
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    # NOTE(review): this is the *_val neighbour cache although test data is
    # evaluated here -- confirm that is intended.
    with open("helper_scripts/tmp/top_1000_neighbors_val.pkl", 'rb') as cache_file:
        cache_e = pickle.load(cache_file)

    # Triple 4996 is removed from the KB and from both answer lists.
    # NOTE(review): presumably the prediction file has no line for it -- confirm.
    test_kb.triples = np.delete(test_kb.triples, 4996, 0)
    del test_kb.e1_all_answers[4996]
    del test_kb.e2_all_answers[4996]

    answers_t = []
    for line in tqdm(xt_lines, desc="nudging"):
        fields = line.strip().split("\t")
        e1 = em_map[fields[0]]
        e1_neighbors = cache_e.get(e1, [])  # [[12,12312],[211,2312],...]
        tmp_answers_t = ast.literal_eval(fields[-1])
        for neighbor in e1_neighbors:
            for answer in tmp_answers_t:
                if neighbor[0] == answer[0]:
                    answer[1] += neighbor[1]
        answers_t.append(tmp_answers_t)

    result = utils.get_metrics_using_topk(
        os.path.join(data_dir, "all_knowns_thorough_linked.pkl"),
        test_kb, answers_t, answers_t, em_map, rm_map)
    print(result)
def main():
    """Evaluate cached neighbour lists as link-prediction answers on validation data.

    Loads a pickled ``entity id -> neighbour list`` cache.  For every validation
    triple, the head entity's cached list is used as the tail answers and the
    tail entity's cached list as the head answers.  The metrics returned by
    ``utils.get_metrics_using_topk`` are printed.

    The ``if (0)`` section is disabled code that was used to build the graph
    pickle and the neighbour cache with 10 worker processes.
    NOTE(review): SOURCE's indentation was lost, so the nesting of that section
    is reconstructed; it must be disabled as a whole, because the live path
    never defines ``graph``.
    """
    K = 1000
    data_dir = "data/olpbench"
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "entity_id_map.txt"))
    relation_mentions, rm_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "relation_id_map.txt"))

    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    test_kb = kb(os.path.join(data_dir, "validation_data_linked.txt"), em_map=em_map, rm_map=rm_map)
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open("helper_scripts/tmp/top_1000_neighbors_val.pkl", 'rb') as cache_file:
        cache_e = pickle.load(cache_file)

    if (0):
        with open("helper_scripts/tmp/graph_thorough_no_r_count_paths.pkl", 'rb') as graph_file:
            graph = pickle.load(graph_file)
        if (0):
            # Rebuild the undirected co-occurrence graph from the training triples.
            train_kb = kb(os.path.join(data_dir, "train_data_thorough.txt"), em_map=None, rm_map=None)
            graph = {}
            for triple in tqdm(train_kb.triples):
                e1 = em_map[triple[0].item()]
                r = rm_map[triple[1].item()]
                e2 = em_map[triple[2].item()]
                if e1 not in graph:
                    graph[e1] = defaultdict(int)
                graph[e1][e2] += 1
                if e2 not in graph:
                    graph[e2] = defaultdict(int)
                graph[e2][e1] += 1
            with open("helper_scripts/tmp/graph_thorough_no_r_count_paths.pkl", 'wb') as graph_file:
                pickle.dump(graph, graph_file)
            exit()

        # Fill a shared cache from 10 worker processes, then persist it.
        manager = Manager()
        cache_e = manager.dict()
        process_array = []
        for i in range(10):
            p = mp.Process(target=func, args=(cache_e, test_kb, i, 1000, em_map, rm_map, graph, K))
            p.start()
            process_array.append(p)
        for p in process_array:
            p.join()
        # Copy the proxy into a plain dict so that it can be pickled.
        cache_e_final = {}
        for key in cache_e:
            cache_e_final[key] = cache_e[key]
        # NOTE(review): this writes top_10000_* while the live path reads
        # top_1000_* -- confirm which file name is intended.
        with open("helper_scripts/tmp/top_10000_neighbors_val.pkl", 'wb') as out_file:
            pickle.dump(cache_e_final, out_file)
        exit()

    answers_t = []
    answers_h = []
    for triple in test_kb.triples:
        e1 = em_map[triple[0].item()]
        r = rm_map[triple[1].item()]  # unused, kept so an unmapped relation still raises
        e2 = em_map[triple[2].item()]
        answers_t.append(cache_e[e1])
        answers_h.append(cache_e[e2])

    metrics = utils.get_metrics_using_topk(
        os.path.join(data_dir, "all_knowns_thorough_linked.pkl"),
        test_kb, answers_t, answers_h, em_map, rm_map)
    print(metrics)
def main():
    """Dump, per relation, the 100 most frequent head and tail entities.

    Counts how often each mapped entity id occurs as head and as tail of every
    relation in ``train_data_thorough.txt``.  Two pickles are written, one for
    heads and one for tails.  Each maps a relation to at most 100
    ``(count, entity_id)`` tuples sorted in descending order.
    """
    data_dir = "../olpbench"
    freq_r_tail = {}
    freq_r_head = {}
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "entity_id_map.txt"))
    _, rm_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "relation_id_map.txt"))
    train_kb = kb(os.path.join(data_dir, "train_data_thorough.txt"),
                  em_map=None,
                  rm_map=None)

    for triple in tqdm(train_kb.triples, desc="getting r freq"):
        e1 = triple[0].item()
        r = triple[1].item()
        e2 = triple[2].item()

        tail_id = em_map[e2]
        tail_counts = freq_r_tail.setdefault(r, {})
        tail_counts[tail_id] = tail_counts.get(tail_id, 0) + 1

        head_id = em_map[e1]
        head_counts = freq_r_head.setdefault(r, {})
        head_counts[head_id] = head_counts.get(head_id, 0) + 1

    def top_100_per_relation(freq_by_relation):
        # (count, entity) tuples sort by count first, then by entity id.
        return {
            relation: sorted(zip(counts.values(), counts.keys()), reverse=True)[:100]
            for relation, counts in freq_by_relation.items()
        }

    # `with` closes the output files even if pickling fails part-way.
    with open("../olpbench/r-freq_top100_thorough_head.pkl", "wb") as f:
        pickle.dump(top_100_per_relation(freq_r_head), f)
    with open("../olpbench/r-freq_top100_thorough_tail.pkl", "wb") as f:
        pickle.dump(top_100_per_relation(freq_r_tail), f)
def main():
    """Score a relation-frequency baseline on the test set.

    For every relation seen in the training triples, count how often each
    mapped entity id occurs as its tail and as its head.  Each test triple is
    then answered with all ``[entity_id, count]`` pairs recorded for its
    relation (an empty list when the relation was never seen).  The metrics
    from ``utils.get_metrics_using_topk`` are printed.
    """
    data_dir = "data/olpbench"
    freq_r_tail, freq_r_head = {}, {}

    entity_mentions, em_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "entity_id_map.txt"))
    _, rm_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "relation_id_map.txt"))
    train_kb = kb(os.path.join(data_dir, "train_data_thorough.txt"), em_map=None, rm_map=None)

    for triple in tqdm(train_kb.triples, desc="getting r freq"):
        head_mention = triple[0].item()
        relation = triple[1].item()
        tail_mention = triple[2].item()

        # Tail statistics for this relation.
        tail_id = em_map[tail_mention]
        per_tail = freq_r_tail.setdefault(relation, {})
        per_tail[tail_id] = per_tail.get(tail_id, 0) + 1

        # Head statistics for this relation.
        head_id = em_map[head_mention]
        per_head = freq_r_head.setdefault(relation, {})
        per_head[head_id] = per_head.get(head_id, 0) + 1

    test_kb = kb(os.path.join(data_dir, "test_data.txt"), em_map=em_map, rm_map=rm_map)

    answers_t, answers_h = [], []
    for triple in test_kb.triples:
        relation = triple[1].item()
        tail_counts = freq_r_tail.get(relation, {})
        answers_t.append([[entity, tail_counts[entity]] for entity in tail_counts])
        head_counts = freq_r_head.get(relation, {})
        answers_h.append([[entity, head_counts[entity]] for entity in head_counts])

    known_path = os.path.join(data_dir, "all_knowns_thorough_linked.pkl")
    metrics = utils.get_metrics_using_topk(known_path, test_kb, answers_t, answers_h, em_map, rm_map)
    print(metrics)
def load_embedding(self):
    """Build an entity embedding table from each entity's most frequent relation.

    Steps:
      1. Load the token/mention maps and the training KB from ``DATA_DIR``.
      2. Pre-compute token indices and lengths for all relation mentions.
      3. For every entity, find the relation it co-occurs with most often as
         head or tail in the training triples.
      4. Encode that relation with a pretrained ``complexLSTM_2_all_e``
         checkpoint and store the concatenated real and imaginary parts as the
         entity's row.  Entities absent from the training triples get a zero row.

    Returns:
        torch.nn.Embedding: table of shape ``(self.entity_count, 512)``.

    NOTE(review): the fill loop runs over the hard-coded 2473409 entities while
    the table has ``self.entity_count`` rows -- confirm that the two agree.
    """
    # Step 1 get train data
    DATA_DIR = "/home/yatin/mayank/olpbench"
    etokens, etoken_map = utils.get_tokens_map(os.path.join(DATA_DIR, "mapped_to_ids", "entity_token_id_map.txt"))
    rtokens, rtoken_map = utils.get_tokens_map(os.path.join(DATA_DIR, "mapped_to_ids", "relation_token_id_map.txt"))
    entity_mentions, em_map = utils.read_mentions(os.path.join(DATA_DIR, "mapped_to_ids", "entity_id_map.txt"))
    relation_mentions, rm_map = utils.read_mentions(os.path.join(DATA_DIR, "mapped_to_ids", "relation_id_map.txt"))
    train_kb = kb(os.path.join(DATA_DIR, "train_data_thorough.txt"), em_map=em_map, rm_map=rm_map)

    # Step 2 get those 2 helper things for relation
    relation_token_indices, relation_lengths = utils.get_token_indices_from_mention_indices(
        relation_mentions, rtoken_map, maxlen=10, use_tqdm=True)

    # Step 3 for each entity get the top frequent relation
    freq_e = {}
    for triple in tqdm(train_kb.triples):
        e1 = em_map[triple[0].item()]
        r = rm_map[triple[1].item()]
        e2 = em_map[triple[2].item()]
        # Head and tail are counted separately, so a self-loop counts twice.
        for entity in (e1, e2):
            relation_counts = freq_e.setdefault(entity, {})
            relation_counts[r] = relation_counts.get(r, 0) + 1
    for key in tqdm(freq_e):
        # Replace the count dict with the single most frequent relation id.
        freq_e[key] = max(freq_e[key], key=freq_e[key].get)

    # Step 4 get the embedding for that relation and save it in torch.nn.embedding against that entity
    from models import complexLSTM_2_all_e
    model = complexLSTM_2_all_e(196007, 39303, 2473409, 512, lstm_dropout=0.1)
    print("Resuming...")
    checkpoint = torch.load("/home/yatin/mayank/olpbench/models/checkpoint_epoch_43", map_location="cpu")
    model.load_state_dict(checkpoint['state_dict'])
    embedding = torch.nn.Embedding(self.entity_count, 512)
    model.eval()
    model.to("cuda")

    # Pure inference: no_grad avoids building an autograd graph per entity.
    with torch.no_grad():
        for entity in trange(2473409, desc="Creating entity tensors finally!"):
            if entity not in freq_e:
                embedding.weight.data[entity] = torch.zeros(512)
            else:
                r_mention_tensor, r_lengths = convert_mention_to_token_indices(
                    [freq_e[entity]], relation_token_indices, relation_lengths)
                r_mention_tensor, r_lengths = r_mention_tensor.cuda(), r_lengths.cuda()
                r_real_lstm, r_img_lstm = model.get_mention_embedding(r_mention_tensor, 1, r_lengths)
                r_real_lstm = r_real_lstm[0]
                r_img_lstm = r_img_lstm[0]
                embedding.weight.data[entity] = torch.cat([r_real_lstm, r_img_lstm]).cpu()
    return embedding
def main():
    """Count validation triples whose head and tail are linked within K hops.

    Loads a pickled training graph and, for every triple of
    ``validation_data_linked.txt``, calls ``bfs(e1, e2, K, graph)``.  The number
    of triples for which it returns a truthy value is printed.

    The ``if (0)`` block is disabled code that rebuilds the graph from the
    training triples.  Each entity maps to ``[neighbour, relation]`` pairs, and
    inverse edges use ``len(relation_mentions) + r`` as their relation id.
    """
    K = 2
    data_dir = "data/olpbench"
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "entity_id_map.txt"))
    relation_mentions, rm_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "relation_id_map.txt"))

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open("helper_scripts/tmp/graph_thorough.pkl", 'rb') as graph_file:
        graph = pickle.load(graph_file)

    if (0):
        train_kb = kb(os.path.join(data_dir, "train_data_thorough.txt"),
                      em_map=None,
                      rm_map=None)
        graph = {}
        for triple in tqdm(train_kb.triples):
            e1 = em_map[triple[0].item()]
            r = rm_map[triple[1].item()]
            e2 = em_map[triple[2].item()]
            if e1 not in graph:
                graph[e1] = []
            graph[e1].append([e2, r])
            if e2 not in graph:
                graph[e2] = []
            # Inverse edge: offset the relation id past all forward relations.
            graph[e2].append([e1, len(relation_mentions) + r])

    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    test_kb = kb(os.path.join(data_dir, "validation_data_linked.txt"),
                 em_map=None,
                 rm_map=None)
    count = 0
    for triple in tqdm(test_kb.triples, desc="test triples"):
        e1 = em_map[triple[0].item()]
        r = rm_map[triple[1].item()]  # unused, kept so an unmapped relation still raises
        e2 = em_map[triple[2].item()]
        if bfs(e1, e2, K, graph):
            count += 1
    print(count)
def main(args):
    """Evaluate every checkpoint in ``models/author_data_2lstm_thorough``.

    Builds the model selected by ``args.model`` (and ``args.separate_lstms``).
    For each checkpoint file it then:

      * loads the weights and embeds all entity mentions in chunks;
      * ranks, for every test triple, the gold tail and the gold head against
        all entity mentions, with the known answers masked out;
      * accumulates MR / MRR / Hits@1/10/50 for tails (``_t``), heads (``_h``)
        and their mean, each divided by the number of test triples.

    The checkpoint with the highest averaged ``hits1`` is tracked; its metrics
    and path are printed at the end.  A checkpoint that raises during
    evaluation is reported and skipped.

    Args:
        args: parsed command-line namespace.  Reads ``seed``, ``data_dir``,
            ``do_train``, ``do_eval``, ``model``, ``separate_lstms``,
            ``embedding_dim``, ``initial_token_embedding``, ``lstm_dropout``,
            ``gamma_rotate``, ``train_data_type``, ``max_seq_length`` and
            ``eval_batch_size``.

    Raises:
        ValueError: if neither ``do_train`` nor ``do_eval`` is set.
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # read token maps
    etokens, etoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids", "entity_token_id_map.txt"))
    rtokens, rtoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids", "relation_token_id_map.txt"))
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "entity_id_map.txt"))
    relation_mentions, rm_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "relation_id_map.txt"))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    # train code (+1 for unk token)
    # NOTE(review): any other value of args.model leaves `model` unbound; every
    # checkpoint is then reported as failed below.
    if args.model == "complex":
        if args.separate_lstms:
            model = complexLSTM_2(
                len(etoken_map) + 1,
                len(rtoken_map) + 1,
                args.embedding_dim,
                initial_token_embedding=args.initial_token_embedding,
                entity_tokens=etokens,
                relation_tokens=rtokens,
                lstm_dropout=args.lstm_dropout)
        else:
            model = complexLSTM(
                len(etoken_map) + 1,
                len(rtoken_map) + 1,
                args.embedding_dim,
                initial_token_embedding=args.initial_token_embedding,
                entity_tokens=etokens,
                relation_tokens=rtokens,
                lstm_dropout=args.lstm_dropout)
    elif args.model == "rotate":
        model = rotatELSTM(
            len(etoken_map) + 1,
            len(rtoken_map) + 1,
            args.embedding_dim,
            initial_token_embedding=args.initial_token_embedding,
            entity_tokens=etokens,
            relation_tokens=rtokens,
            gamma=args.gamma_rotate,
            lstm_dropout=args.lstm_dropout)

    if args.do_eval:
        best_model = -1
        best_metrics = None
        if "olpbench" in args.data_dir:
            test_kb = kb(os.path.join(args.data_dir, "test_data.txt"),
                         em_map=em_map,
                         rm_map=rm_map)
        else:
            test_kb = kb(os.path.join(args.data_dir, "test.txt"),
                         em_map=em_map,
                         rm_map=rm_map)

        print("Loading all_known pickled data...(takes times since large)")
        known_path = os.path.join(
            args.data_dir,
            "all_knowns_{}_linked.pkl".format(args.train_data_type))
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(known_path, "rb") as known_file:
            all_known_e2, all_known_e1 = pickle.load(known_file)

        # Tokenisation depends only on the data, not on the checkpoint, so it
        # is done once here instead of once per model.
        entity_tokens_cpu, entity_lengths_cpu = convert_string_to_indices(
            entity_mentions,
            etoken_map,
            maxlen=args.max_seq_length,
            use_tqdm=False)
        split = 100  # cant fit all in gpu together. hence split
        # max(1, ...) keeps range() from receiving a zero step when there are
        # fewer than `split` mentions.
        chunk = max(1, len(entity_tokens_cpu) // split)

        test_e1_tokens_tensor, test_e1_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 0], etoken_map, maxlen=args.max_seq_length)
        test_r_tokens_tensor, test_r_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 1], rtoken_map, maxlen=args.max_seq_length)
        test_e2_tokens_tensor, test_e2_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 2], etoken_map, maxlen=args.max_seq_length)
        # indices would be used to fetch alternative answers while evaluating
        indices = torch.Tensor(range(len(test_kb.triples)))
        test_data = TensorDataset(indices, test_e1_tokens_tensor,
                                  test_r_tokens_tensor, test_e2_tokens_tensor,
                                  test_e1_tokens_lengths, test_r_tokens_lengths,
                                  test_e2_tokens_lengths)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.eval_batch_size)

        # The original computed 4 or 6 for large olpbench embeddings and then
        # unconditionally overrode the value with 1; only the override is kept.
        split_dim_for_eval = 1

        def filtered_rank(simi, gold_ids, known_ids):
            # Rank of the best-scoring gold mention once all known answers are
            # masked out: 1 + #(scores >= best) + #(scores == best) / 2.
            best_score = simi[gold_ids].max()
            simi[known_ids] = -20000000  # MOST NEGATIVE VALUE
            greatereq = simi.ge(best_score).float()
            equal = simi.eq(best_score).float()
            return greatereq.sum() + 1 + equal.sum() / 2.0

        metric_names = ('mr', 'mrr', 'hits1', 'hits10', 'hits50')
        model_dir = "models/author_data_2lstm_thorough"
        for model_name in tqdm(os.listdir(model_dir)):
            model_path = os.path.join(model_dir, model_name)
            try:
                # eval code
                metrics = {
                    name + suffix: 0
                    for suffix in ("", "_t", "_h") for name in metric_names
                }
                checkpoint = torch.load(
                    model_path, map_location=lambda storage, loc: storage)
                model.load_state_dict(checkpoint['state_dict'])
                model.eval()

                # get embeddings for all entity mentions
                entity_mentions_tensor = entity_tokens_cpu.cuda()
                entity_mentions_lengths = entity_lengths_cpu.cuda()
                ementions_real_lis = []
                ementions_img_lis = []
                with torch.no_grad():
                    for i in range(0, len(entity_mentions_tensor), chunk):
                        data = entity_mentions_tensor[i:i + chunk, :]
                        data_lengths = entity_mentions_lengths[i:i + chunk]
                        ementions_real_lstm, ementions_img_lstm = model.get_mention_embedding(
                            data, 0, data_lengths)
                        ementions_real_lis.append(ementions_real_lstm.cpu())
                        ementions_img_lis.append(ementions_img_lstm.cpu())
                # Drop GPU references before concatenating on the GPU again.
                data = data_lengths = None
                ementions_real_lstm = ementions_img_lstm = None
                del entity_mentions_tensor, entity_mentions_lengths
                torch.cuda.empty_cache()
                ementions_real = torch.cat(ementions_real_lis).cuda()
                ementions_img = torch.cat(ementions_img_lis).cuda()

                for index, test_e1_tokens, test_r_tokens, test_e2_tokens, \
                        test_e1_lengths, test_r_lengths, test_e2_lengths in test_dataloader:
                    test_e1_tokens, test_e1_lengths = test_e1_tokens.to(
                        device), test_e1_lengths.to(device)
                    test_r_tokens, test_r_lengths = test_r_tokens.to(
                        device), test_r_lengths.to(device)
                    test_e2_tokens, test_e2_lengths = test_e2_tokens.to(
                        device), test_e2_lengths.to(device)
                    with torch.no_grad():
                        e1_real_lstm, e1_img_lstm = model.get_mention_embedding(
                            test_e1_tokens, 0, test_e1_lengths)
                        r_real_lstm, r_img_lstm = model.get_mention_embedding(
                            test_r_tokens, 1, test_r_lengths)
                        e2_real_lstm, e2_img_lstm = model.get_mention_embedding(
                            test_e2_tokens, 0, test_e2_lengths)

                        for count in range(index.shape[0]):
                            this_e1_real = e1_real_lstm[count].unsqueeze(0)
                            this_e1_img = e1_img_lstm[count].unsqueeze(0)
                            this_r_real = r_real_lstm[count].unsqueeze(0)
                            this_r_img = r_img_lstm[count].unsqueeze(0)
                            this_e2_real = e2_real_lstm[count].unsqueeze(0)
                            this_e2_img = e2_img_lstm[count].unsqueeze(0)

                            simi_t = model.complex_score_e1_r_with_all_ementions(
                                this_e1_real,
                                this_e1_img,
                                this_r_real,
                                this_r_img,
                                ementions_real,
                                ementions_img,
                                split=split_dim_for_eval).squeeze(0)
                            simi_h = model.complex_score_e2_r_with_all_ementions(
                                this_e2_real,
                                this_e2_img,
                                this_r_real,
                                this_r_img,
                                ementions_real,
                                ementions_img,
                                split=split_dim_for_eval).squeeze(0)

                            # get known answers for filtered ranking
                            row = int(index[count].item())
                            triple = test_kb.triples[row]
                            this_correct_mentions_e2 = test_kb.e2_all_answers[row]
                            this_correct_mentions_e1 = test_kb.e1_all_answers[row]
                            all_correct_mentions_e2 = all_known_e2.get(
                                (em_map[triple[0]], rm_map[triple[1]]), [])
                            all_correct_mentions_e1 = all_known_e1.get(
                                (em_map[triple[2]], rm_map[triple[1]]), [])

                            # compute metrics
                            rank_t = filtered_rank(simi_t,
                                                   this_correct_mentions_e2,
                                                   all_correct_mentions_e2)
                            rank_h = filtered_rank(simi_h,
                                                   this_correct_mentions_e1,
                                                   all_correct_mentions_e1)
                            for suffix, rank in (("_t", rank_t), ("_h", rank_h)):
                                metrics['mr' + suffix] += rank
                                metrics['mrr' + suffix] += 1.0 / rank
                                metrics['hits1' + suffix] += rank.le(1).float()
                                metrics['hits10' + suffix] += rank.le(10).float()
                                metrics['hits50' + suffix] += rank.le(50).float()

                for name in metric_names:
                    metrics[name] = (metrics[name + '_h'] + metrics[name + '_t']) / 2
                for key in metrics:
                    metrics[key] = metrics[key] / len(test_kb.triples)

                if best_metrics is None or best_metrics['hits1'] < metrics['hits1']:
                    best_model = model_path
                    best_metrics = metrics
                print("best_hits1:", best_metrics['hits1'])
            except Exception as err:
                # Best-effort sweep: skip checkpoints that cannot be evaluated,
                # but say why instead of hiding the failure.
                print("Skipping {}: {!r}".format(model_path, err))
                continue

        print(best_metrics)
        print(best_model)
def _read(self, labels_fp):
    """Yield one instance per mention listed in ``labels_fp``.

    The mention list is the first value returned by ``utils.read_mentions``.
    Each mention is passed to ``self.text_to_instance`` together with its
    zero-based position in that list.
    """
    mentions, _unused_map = utils.read_mentions(labels_fp)
    position = 0
    for mention in mentions:
        yield self.text_to_instance(mention, position)
        position += 1
def main(args):
    """Print random samples of test triples that a complexLSTM ranks best and worst.

    Every test triple is ranked in the direction given by ``args.head_or_tail``
    (``"tail"`` ranks tails, any other value ranks heads), with the known
    answers masked out.  Triples with rank <= 1 and triples with rank > 50 are
    collected.  Up to ``args.sample`` of each group are printed with their
    correct mentions; the rank > 50 group also shows the model's top-10
    mentions (the hits@1 group stores an empty list there).

    Args:
        args: parsed command-line namespace.  Reads ``data_dir``, ``seed``,
            ``embedding_dim``, ``initial_token_embedding``, ``resume``,
            ``max_seq_length``, ``eval_batch_size``, ``head_or_tail`` and
            ``sample``.
    """
    hits_1_triple = []
    hits_1_correct_answers = []
    hits_1_model_top10 = []
    nothits_50_triple = []
    nothits_50_correct_answers = []
    nothits_50_model_top10 = []

    # read token maps
    etokens, etoken_map = utils.get_tokens_map(os.path.join(args.data_dir, "mapped_to_ids", "entity_token_id_map.txt"))
    rtokens, rtoken_map = utils.get_tokens_map(os.path.join(args.data_dir, "mapped_to_ids", "relation_token_id_map.txt"))
    entity_mentions, em_map = utils.read_mentions(os.path.join(args.data_dir, "mapped_to_ids", "entity_id_map.txt"))
    _, rm_map = utils.read_mentions(os.path.join(args.data_dir, "mapped_to_ids", "relation_id_map.txt"))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # train code (+1 for unk token)
    model = complexLSTM(len(etoken_map) + 1, len(rtoken_map) + 1, args.embedding_dim,
                        initial_token_embedding=args.initial_token_embedding,
                        entity_tokens=etokens, relation_tokens=rtokens)
    if args.resume:
        print("Resuming from:", args.resume)
        checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    # get embeddings for all entity mentions
    entity_mentions_tensor, entity_mentions_lengths = convert_string_to_indices(
        entity_mentions, etoken_map, maxlen=args.max_seq_length, use_tqdm=True)
    entity_mentions_tensor = entity_mentions_tensor.cuda()
    entity_mentions_lengths = entity_mentions_lengths.cuda()
    ementions_real_lis = []
    ementions_img_lis = []
    split = 100  # cant fit all in gpu together. hence split
    # max(1, ...) keeps range() from receiving a zero step when there are fewer
    # than `split` mentions.
    chunk = max(1, len(entity_mentions_tensor) // split)
    with torch.no_grad():
        for i in tqdm(range(0, len(entity_mentions_tensor), chunk)):
            data = entity_mentions_tensor[i:i + chunk, :]
            data_lengths = entity_mentions_lengths[i:i + chunk]
            ementions_real_lstm, ementions_img_lstm = model.get_mention_embedding(data, 0, data_lengths)
            ementions_real_lis.append(ementions_real_lstm.cpu())
            ementions_img_lis.append(ementions_img_lstm.cpu())
    # Drop GPU references before the concatenated copies are moved back.
    data = data_lengths = None
    ementions_real_lstm = ementions_img_lstm = None
    del entity_mentions_tensor
    torch.cuda.empty_cache()
    ementions_real = torch.cat(ementions_real_lis).cuda()
    ementions_img = torch.cat(ementions_img_lis).cuda()

    if "olpbench" in args.data_dir:
        test_kb = kb(os.path.join(args.data_dir, "test_data.txt"), em_map=em_map, rm_map=rm_map)
    else:
        test_kb = kb(os.path.join(args.data_dir, "test.txt"), em_map=em_map, rm_map=rm_map)

    print("Loading all_known pickled data...(takes times since large)")
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(os.path.join(args.data_dir, "all_knowns_simple_linked.pkl"), "rb") as known_file:
        all_known_e2, all_known_e1 = pickle.load(known_file)

    test_e1_tokens_tensor, test_e1_tokens_lengths = convert_string_to_indices(
        test_kb.triples[:, 0], etoken_map, maxlen=args.max_seq_length)
    test_r_tokens_tensor, test_r_tokens_lengths = convert_string_to_indices(
        test_kb.triples[:, 1], rtoken_map, maxlen=args.max_seq_length)
    test_e2_tokens_tensor, test_e2_tokens_lengths = convert_string_to_indices(
        test_kb.triples[:, 2], etoken_map, maxlen=args.max_seq_length)
    # indices would be used to fetch alternative answers while evaluating
    indices = torch.Tensor(range(len(test_kb.triples)))
    test_data = TensorDataset(indices, test_e1_tokens_tensor, test_r_tokens_tensor,
                              test_e2_tokens_tensor, test_e1_tokens_lengths,
                              test_r_tokens_lengths, test_e2_tokens_lengths)
    test_sampler = SequentialSampler(test_data)
    test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=args.eval_batch_size)

    split_dim_for_eval = 1
    if args.embedding_dim >= 512 and "olpbench" in args.data_dir:
        split_dim_for_eval = 4

    rank_tail = args.head_or_tail == "tail"
    for index, test_e1_tokens, test_r_tokens, test_e2_tokens, test_e1_lengths, \
            test_r_lengths, test_e2_lengths in tqdm(test_dataloader, desc="Test dataloader"):
        test_e1_tokens, test_e1_lengths = test_e1_tokens.to(device), test_e1_lengths.to(device)
        test_r_tokens, test_r_lengths = test_r_tokens.to(device), test_r_lengths.to(device)
        test_e2_tokens, test_e2_lengths = test_e2_tokens.to(device), test_e2_lengths.to(device)
        with torch.no_grad():
            e1_real_lstm, e1_img_lstm = model.get_mention_embedding(test_e1_tokens, 0, test_e1_lengths)
            r_real_lstm, r_img_lstm = model.get_mention_embedding(test_r_tokens, 1, test_r_lengths)
            e2_real_lstm, e2_img_lstm = model.get_mention_embedding(test_e2_tokens, 0, test_e2_lengths)

            for count in tqdm(range(index.shape[0]), desc="Evaluating"):
                this_r_real = r_real_lstm[count].unsqueeze(0)
                this_r_img = r_img_lstm[count].unsqueeze(0)

                # get known answers for filtered ranking
                row = int(index[count].item())
                triple = test_kb.triples[row]
                if rank_tail:
                    gold_mentions = test_kb.e2_all_answers[row]
                    known_mentions = all_known_e2.get((em_map[triple[0]], rm_map[triple[1]]), [])
                    simi = model.complex_score_e1_r_with_all_ementions(
                        e1_real_lstm[count].unsqueeze(0), e1_img_lstm[count].unsqueeze(0),
                        this_r_real, this_r_img, ementions_real, ementions_img,
                        split=split_dim_for_eval).squeeze(0)
                else:
                    gold_mentions = test_kb.e1_all_answers[row]
                    known_mentions = all_known_e1.get((em_map[triple[2]], rm_map[triple[1]]), [])
                    simi = model.complex_score_e2_r_with_all_ementions(
                        e2_real_lstm[count].unsqueeze(0), e2_img_lstm[count].unsqueeze(0),
                        this_r_real, this_r_img, ementions_real, ementions_img,
                        split=split_dim_for_eval).squeeze(0)

                # rank = 1 + #(scores >= best gold) + #(scores == best gold) / 2,
                # counted after all known answers have been masked out.
                best_score = simi[gold_mentions].max()
                simi[known_mentions] = -20000000  # MOST NEGATIVE VALUE
                greatereq = simi.ge(best_score).float()
                equal = simi.eq(best_score).float()
                rank = greatereq.sum() + 1 + equal.sum() / 2.0

                if rank <= 1:  # hits1
                    hits_1_triple.append([triple[0], triple[1], triple[2]])
                    hits_1_correct_answers.append([entity_mentions[x] for x in gold_mentions])
                    hits_1_model_top10.append([])
                elif rank > 50:  # nothits50
                    nothits_50_triple.append([triple[0], triple[1], triple[2]])
                    nothits_50_correct_answers.append([entity_mentions[x] for x in gold_mentions])
                    # Top 10 is taken after masking, so known answers are excluded.
                    tmp = simi.sort()[1].tolist()[::-1][:10]
                    nothits_50_model_top10.append([entity_mentions[x] for x in tmp])

    indices = list(range(len(hits_1_triple)))
    random.shuffle(indices)
    indices = indices[:args.sample]
    for ind in indices:
        print(ind, "|", hits_1_triple[ind], "|", hits_1_correct_answers[ind], "|", hits_1_model_top10[ind])
    print("---------------------------------------------------------------------------------------------")
    indices = list(range(len(nothits_50_triple)))
    random.shuffle(indices)
    indices = indices[:args.sample]
    for ind in indices:
        print(ind, "|", nothits_50_triple[ind], "|", nothits_50_correct_answers[ind], "|", nothits_50_model_top10[ind])
def main():
    """Run ``func`` over the test set in parallel and write the merged ``.xt`` lines.

    Starts ``NPROCS`` worker processes, each given its worker index and
    ``LOCAL_SIZE``.  Workers report through manager-backed dicts: a numerator
    and a denominator per key, plus a list of output lines.  After all workers
    finish, the totals are summed, ``num / den`` and ``num / 10000`` are
    printed, and the collected lines are written to the output file.

    The ``if (0)`` block is disabled code that rebuilds the graph pickle from
    the training triples and then exits.
    """
    K = 1000000000000000000
    data_dir = "data/olpbench"
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "entity_id_map.txt"))
    relation_mentions, rm_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "relation_id_map.txt"))

    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    NPROCS = 50
    LOCAL_SIZE = 10000 // NPROCS
    test_kb = kb(os.path.join(data_dir, "test_data.txt"),
                 em_map=em_map,
                 rm_map=rm_map)

    if (1):
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open("helper_scripts/tmp/graph_thorough_no_r_count_paths.pkl", 'rb') as graph_file:
            graph = pickle.load(graph_file)
    if (0):
        train_kb = kb(os.path.join(data_dir, "train_data_thorough.txt"),
                      em_map=None,
                      rm_map=None)
        graph = {}
        for triple in tqdm(train_kb.triples):
            e1 = em_map[triple[0].item()]
            r = rm_map[triple[1].item()]
            e2 = em_map[triple[2].item()]
            if e1 not in graph:
                graph[e1] = defaultdict(int)
            graph[e1][e2] += 1
            if e2 not in graph:
                graph[e2] = defaultdict(int)
            graph[e2][e1] += 1
        with open("helper_scripts/tmp/graph_thorough_no_r_count_paths.pkl", 'wb') as graph_file:
            pickle.dump(graph, graph_file)
        exit()

    manager = Manager()
    cache_e = manager.dict()
    test_numerator = manager.dict()
    test_denominator = manager.dict()
    test_sub_lines = manager.dict()
    for i in range(NPROCS):
        test_sub_lines[i] = manager.list()

    process_array = []
    for i in range(NPROCS):
        # makedirs creates missing parents and tolerates an existing directory,
        # and it avoids spawning a shell for every worker.
        os.makedirs("helper_scripts/xt_tmp_folder/" + str(i), exist_ok=True)
        p = mp.Process(target=func,
                       args=(test_numerator, test_denominator, test_sub_lines,
                             cache_e, test_kb, i, LOCAL_SIZE, em_map, rm_map,
                             graph, K))
        p.start()
        process_array.append(p)
    for p in process_array:
        p.join()

    num = 0
    den = 0
    final_test_lines = []
    for key in test_numerator:
        num += test_numerator[key]
        den += test_denominator[key]
        final_test_lines.extend(test_sub_lines[key])

    # A zero denominator (no worker reported anything) must not abort the run
    # before the output file is written.
    print(num / den if den else float("nan"))
    print(num / 10000)

    with open("helper_scripts/tmp/test_data-full_neighbors-subset_2hop.txt.tail.xt", 'w') as f:
        f.writelines(final_test_lines)
def main():
    """Prefix each test ``.xt`` line with labels for the head's K-hop neighbours.

    For every test triple, ``bfs(e1, K, graph)`` is called and at most the
    first 10000 returned neighbours are kept.  One ``__label__<id>`` token is
    written per neighbour, followed by the matching line of
    ``test_data.txt.tail.xt`` with everything up to its first space removed.
    The number of triples whose tail entity is among the kept neighbours is
    printed.

    The ``if (0)`` block is disabled code that rebuilds the relation-free graph
    pickle from the training triples and then exits.
    """
    K = 2
    data_dir = "data/olpbench"
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "entity_id_map.txt"))
    relation_mentions, rm_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "relation_id_map.txt"))

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open("helper_scripts/tmp/graph_thorough_no_r.pkl", 'rb') as graph_file:
        graph = pickle.load(graph_file)

    if (0):
        train_kb = kb(os.path.join(data_dir, "train_data_thorough.txt"),
                      em_map=None,
                      rm_map=None)
        graph = {}
        for triple in tqdm(train_kb.triples):
            e1 = em_map[triple[0].item()]
            r = rm_map[triple[1].item()]
            e2 = em_map[triple[2].item()]
            if e1 not in graph:
                graph[e1] = []
            graph[e1].append(e2)
            if e2 not in graph:
                graph[e2] = []
            graph[e2].append(e1)
        with open("helper_scripts/tmp/graph_thorough_no_r.pkl", 'wb') as graph_file:
            pickle.dump(graph, graph_file)
        exit()

    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    test_kb = kb(os.path.join(data_dir, "test_data.txt"),
                 em_map=None,
                 rm_map=None)

    with open("helper_scripts/tmp/test_data.txt.tail.xt", 'r') as xt_file:
        # Drop the leading token (everything up to the first space) of each line.
        keshav_xt_lines = [
            line[line.index(" ") + 1:].strip() for line in xt_file.readlines()
        ]

    count = 0
    with open("helper_scripts/tmp/test_data.txt.tail.xt.all-2-hop-neighbors", 'w') as f:
        for ind, triple in tqdm(enumerate(test_kb.triples), desc="test triples"):
            e1 = em_map[triple[0].item()]
            r = rm_map[triple[1].item()]  # unused, kept so an unmapped relation still raises
            e2 = em_map[triple[2].item()]
            e1_neighbours = bfs(e1, K, graph)[:10000]
            if e2 in e1_neighbours:
                count += 1
            # join avoids quadratic string growth for up to 10000 labels.
            neighbour_string = "".join(
                "__label__" + str(neighbor) + " " for neighbor in e1_neighbours)
            f.write(neighbour_string + keshav_xt_lines[ind] + "\n")
            # Flush per triple so that partial output survives an interruption.
            f.flush()
    print(count)
def _shift_author_indices(indices, num_mentions):
    """Translate the data loader's mention ids into this file's id space.

    The loader's ids are two larger than the ids of our maps, so 2 is
    subtracted. Ids that were 0 (i.e. -2 after the shift) are replaced by
    ``num_mentions - 2`` (presumably the slot reserved for unknown mentions --
    TODO confirm against the dataset code).
    """
    shifted = indices - 2
    return shifted.where(
        shifted != -2, torch.tensor(num_mentions - 2, dtype=torch.int32))


def _plt_filtered_rank(scores, gold_mentions, known_mentions):
    """Return the filtered rank of the best-scoring gold mention.

    ``scores`` is modified in place: every known correct mention is pushed to
    a very negative value before counting, so other correct answers do not
    compete with the gold one.

    The rank is ``#(score >= best) + 1 + #(score == best) / 2``.
    """
    best_score = scores[gold_mentions].max()
    scores[known_mentions] = -20000000  # MOST NEGATIVE VALUE
    greatereq = scores.ge(best_score).float()
    equal = scores.eq(best_score).float()
    return greatereq.sum() + 1 + equal.sum() / 2.0


def main(args):
    """Train and/or evaluate ``complexLSTM_2_all_e_PLT``.

    Attributes read from ``args``: ``seed``, ``data_dir``, ``max_seq_length``,
    ``do_train``, ``do_eval``, ``embedding_dim``, ``initial_token_embedding``,
    ``lstm_dropout``, ``train_batch_size``, ``learning_rate``, ``resume``,
    ``num_train_epochs``, ``print_loss_every``, ``save_model_every``,
    ``output_dir``, ``train_data_type`` and ``eval_batch_size``.

    Training minimises ``-logsigmoid(score * mask).mean()`` over the positive
    and (sign-flipped) negative tree nodes returned by
    ``model.get_tree_nodes`` and checkpoints every ``save_model_every``
    epochs. Evaluation prints filtered MR / MRR / Hits@{1,10,50} for head and
    tail prediction plus their average.

    Raises:
        ValueError: if neither ``args.do_train`` nor ``args.do_eval`` is set.
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Fail fast, before any of the large map files is read.
    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    mapped_dir = os.path.join(args.data_dir, "mapped_to_ids")

    # read token maps
    etokens, etoken_map = utils.get_tokens_map(
        os.path.join(mapped_dir, "entity_token_id_map.txt"))
    rtokens, rtoken_map = utils.get_tokens_map(
        os.path.join(mapped_dir, "relation_token_id_map.txt"))
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(mapped_dir, "entity_id_map.txt"))
    relation_mentions, rm_map = utils.read_mentions(
        os.path.join(mapped_dir, "relation_id_map.txt"))

    # Per-mention token index rows and the matching lengths:
    # [[max length indices for entity 0], [max length indices for entity 1], ...]
    # [length of entity 0, length of entity 1, ...]
    entity_token_indices, entity_lengths = utils.get_token_indices_from_mention_indices(
        entity_mentions,
        etoken_map,
        maxlen=args.max_seq_length,
        use_tqdm=True)
    relation_token_indices, relation_lengths = utils.get_token_indices_from_mention_indices(
        relation_mentions,
        rtoken_map,
        maxlen=args.max_seq_length,
        use_tqdm=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # +1 for unk token
    model = complexLSTM_2_all_e_PLT(
        len(etoken_map) + 1,
        len(rtoken_map) + 1,
        len(entity_mentions),
        args.embedding_dim, {},
        initial_token_embedding=args.initial_token_embedding,
        entity_tokens=etokens,
        relation_tokens=rtokens,
        lstm_dropout=args.lstm_dropout)
    logsigmoid = torch.nn.LogSigmoid()

    if args.do_train:
        data_config = {
            'input_file': 'train_data_thorough.txt',
            'batch_size': args.train_batch_size,
            'use_batch_shared_entities': True,
            'min_size_batch_labels': args.train_batch_size,
            'max_size_prefix_label': 64,
            'device': 0
        }
        expt_settings = {
            'loss': 'bce',
            'replace_entities_by_tokens': True,
            'replace_relations_by_tokens': True,
            'max_lengths_tuple': [10, 10]
        }
        train_data = OneToNMentionRelationDataset(dataset_dir=mapped_dir,
                                                  is_training_data=True,
                                                  **data_config,
                                                  **expt_settings)
        train_data.create_data_tensors(
            dataset_dir=mapped_dir,
            train_input_file='train_data_thorough.txt',
            valid_input_file='validation_data_linked.txt',
            test_input_file='test_data.txt',
        )
        train_loader = train_data.get_loader(
            shuffle=True,
            num_workers=8,
            drop_last=True,
        )

        optimizer = torch.optim.Adagrad(model.parameters(),
                                        lr=args.learning_rate)
        if args.resume:
            print("Resuming from:", args.resume)
            checkpoint = torch.load(args.resume, map_location="cpu")
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Free the checkpoint copy before training starts.
            del checkpoint
            torch.cuda.empty_cache()
        model.train()

        for epoch in tqdm(range(0, args.num_train_epochs), desc="epoch"):
            iteration = 0
            for batch in tqdm(train_loader, desc="Train dataloader"):
                inputs, _, _, labels, _, _, batch_shared_entities = \
                    train_data.input_and_labels_to_device(
                        batch, training=True, device="cpu")

                if inputs[0] is None:
                    num_samples_for_head = 0
                else:
                    num_samples_for_head = inputs[0][0].shape[0]

                (tree_nodes_head, head_mask, tree_nodes_tail, tail_mask,
                 tree_nodes_head_neg, head_mask_neg, tree_nodes_tail_neg,
                 tail_mask_neg) = model.get_tree_nodes(
                     batch_shared_entities - 2, labels, num_samples_for_head)

                loss = torch.tensor(0, device=device)
                for mode, model_inputs in zip(["head", "tail"], inputs):
                    if model_inputs is None:
                        continue
                    # subtract two from author's indices because our map is 2 less
                    if mode == "head":
                        batch_e2_indices = _shift_author_indices(
                            model_inputs[1], len(entity_mentions))
                        batch_r_indices = _shift_author_indices(
                            model_inputs[0], len(relation_mentions))

                        train_r_mention_tensor, train_r_lengths = convert_mention_to_token_indices(
                            batch_r_indices.squeeze(1),
                            relation_token_indices, relation_lengths)
                        train_e2_mention_tensor, train_e2_lengths = convert_mention_to_token_indices(
                            batch_e2_indices.squeeze(1), entity_token_indices,
                            entity_lengths)
                        train_r_mention_tensor = train_r_mention_tensor.cuda()
                        train_r_lengths = train_r_lengths.cuda()
                        train_e2_mention_tensor = train_e2_mention_tensor.cuda()
                        train_e2_lengths = train_e2_lengths.cuda()

                        r_real_lstm, r_img_lstm = model.get_mention_embedding(
                            train_r_mention_tensor, 1, train_r_lengths)
                        e2_real_lstm, e2_img_lstm = model.get_mention_embedding(
                            train_e2_mention_tensor, 0, train_e2_lengths)

                        # Positives and negatives are scored together; the
                        # negative mask is sign-flipped in place so that
                        # logsigmoid pushes negative scores down.
                        tmp_nodes_0 = torch.cat(
                            [tree_nodes_head[0], tree_nodes_head_neg[0]],
                            dim=1)
                        tmp_nodes_1 = torch.cat(
                            [tree_nodes_head[1], tree_nodes_head_neg[1]],
                            dim=1)
                        head_mask_neg *= -1
                        tmp_mask = torch.cat([head_mask, head_mask_neg],
                                             dim=1)
                        model_output = model.complex_score_e2_r_with_given_ementions(
                            e2_real_lstm, e2_img_lstm, r_real_lstm,
                            r_img_lstm, tmp_nodes_0, tmp_nodes_1)
                    else:
                        batch_e1_indices = _shift_author_indices(
                            model_inputs[0], len(entity_mentions))
                        batch_r_indices = _shift_author_indices(
                            model_inputs[1], len(relation_mentions))

                        train_e1_mention_tensor, train_e1_lengths = convert_mention_to_token_indices(
                            batch_e1_indices.squeeze(1), entity_token_indices,
                            entity_lengths)
                        train_r_mention_tensor, train_r_lengths = convert_mention_to_token_indices(
                            batch_r_indices.squeeze(1),
                            relation_token_indices, relation_lengths)
                        train_e1_mention_tensor = train_e1_mention_tensor.cuda()
                        train_e1_lengths = train_e1_lengths.cuda()
                        train_r_mention_tensor = train_r_mention_tensor.cuda()
                        train_r_lengths = train_r_lengths.cuda()

                        e1_real_lstm, e1_img_lstm = model.get_mention_embedding(
                            train_e1_mention_tensor, 0, train_e1_lengths)
                        r_real_lstm, r_img_lstm = model.get_mention_embedding(
                            train_r_mention_tensor, 1, train_r_lengths)

                        tmp_nodes_0 = torch.cat(
                            [tree_nodes_tail[0], tree_nodes_tail_neg[0]],
                            dim=1)
                        tmp_nodes_1 = torch.cat(
                            [tree_nodes_tail[1], tree_nodes_tail_neg[1]],
                            dim=1)
                        tail_mask_neg *= -1
                        tmp_mask = torch.cat([tail_mask, tail_mask_neg],
                                             dim=1)
                        model_output = model.complex_score_e1_r_with_given_ementions(
                            e1_real_lstm, e1_img_lstm, r_real_lstm,
                            r_img_lstm, tmp_nodes_0, tmp_nodes_1)

                    loss = loss - (logsigmoid(model_output * tmp_mask)).mean()

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                if iteration % args.print_loss_every == 0:
                    print("Current loss:", loss.item())
                iteration += 1

            if epoch % args.save_model_every == 0:
                utils.save_checkpoint(
                    {
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }, args.output_dir +
                    "/checkpoint_epoch_{}".format(epoch + 1))

    if args.do_eval:
        metric_names = ('mr', 'mrr', 'hits1', 'hits10', 'hits50')
        # Same key order as before: averages first, then tail, then head.
        metrics = {}
        for suffix in ('', '_t', '_h'):
            for name in metric_names:
                metrics[name + suffix] = 0

        if args.resume and not args.do_train:
            print("Resuming from:", args.resume)
            checkpoint = torch.load(args.resume,
                                    map_location=lambda storage, loc: storage)
            model.load_state_dict(checkpoint['state_dict'])
        model.eval()

        cut_nodes_indices = []
        cut_mask = []
        for i in model.hsoftmax.nodes_at_cut:
            cut_nodes_indices.append(model.hsoftmax.node_indices_for_t[i])
            cut_mask.append(model.hsoftmax.mask_for_t[i])
        cut_mask = torch.tensor(cut_mask, device=device)
        cut_nodes_indices = torch.tensor(cut_nodes_indices, device=device)
        with torch.no_grad():
            cut_nodes_real, cut_nodes_img = model.E_atomic(
                cut_nodes_indices).chunk(2, dim=-1)

        if "olpbench" in args.data_dir:
            test_kb = kb(os.path.join(args.data_dir, "test_data.txt"),
                         em_map=em_map,
                         rm_map=rm_map)
        else:
            test_kb = kb(os.path.join(args.data_dir, "test.txt"),
                         em_map=em_map,
                         rm_map=rm_map)

        print("Loading all_known pickled data...(takes times since large)")
        all_knowns_path = os.path.join(
            args.data_dir,
            "all_knowns_{}_linked.pkl".format(args.train_data_type))
        with open(all_knowns_path, "rb") as all_knowns_file:
            all_known_e2, all_known_e1 = pickle.load(all_knowns_file)

        test_e1_tokens_tensor, test_e1_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 0], etoken_map, maxlen=args.max_seq_length)
        test_r_tokens_tensor, test_r_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 1], rtoken_map, maxlen=args.max_seq_length)
        test_e2_tokens_tensor, test_e2_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 2], etoken_map, maxlen=args.max_seq_length)

        # Position of every example in test_kb, used to fetch alternative
        # answers while evaluating. Kept as integers: a float32 tensor cannot
        # represent every index once there are more than 2**24 examples.
        indices = torch.arange(len(test_kb.triples))
        test_data = TensorDataset(indices, test_e1_tokens_tensor,
                                  test_r_tokens_tensor, test_e2_tokens_tensor,
                                  test_e1_tokens_lengths,
                                  test_r_tokens_lengths,
                                  test_e2_tokens_lengths)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.eval_batch_size)

        for index, test_e1_tokens, test_r_tokens, test_e2_tokens, test_e1_lengths, test_r_lengths, test_e2_lengths in tqdm(
                test_dataloader, desc="Test dataloader"):
            print(metrics)  # running (un-normalised) sums
            test_e1_tokens = test_e1_tokens.to(device)
            test_e1_lengths = test_e1_lengths.to(device)
            test_r_tokens = test_r_tokens.to(device)
            test_r_lengths = test_r_lengths.to(device)
            test_e2_tokens = test_e2_tokens.to(device)
            test_e2_lengths = test_e2_lengths.to(device)
            with torch.no_grad():
                e1_real_lstm, e1_img_lstm = model.get_mention_embedding(
                    test_e1_tokens, 0, test_e1_lengths)
                r_real_lstm, r_img_lstm = model.get_mention_embedding(
                    test_r_tokens, 1, test_r_lengths)
                e2_real_lstm, e2_img_lstm = model.get_mention_embedding(
                    test_e2_tokens, 0, test_e2_lengths)

            for count in tqdm(range(index.shape[0]), desc="Evaluating"):
                this_e1_real = e1_real_lstm[count]
                this_e1_img = e1_img_lstm[count]
                this_r_real = r_real_lstm[count]
                this_r_img = r_img_lstm[count]
                this_e2_real = e2_real_lstm[count]
                this_e2_img = e2_img_lstm[count]

                triple_idx = int(index[count].item())
                triple = test_kb.triples[triple_idx]
                this_correct_mentions_e2 = test_kb.e2_all_answers[triple_idx]
                this_correct_mentions_e1 = test_kb.e1_all_answers[triple_idx]

                # get known answers for filtered ranking
                all_correct_mentions_e2 = all_known_e2.get(
                    (em_map[triple[0]], rm_map[triple[1]]), [])
                all_correct_mentions_e1 = all_known_e1.get(
                    (em_map[triple[2]], rm_map[triple[1]]), [])

                # The scoring itself must run without autograd; previously
                # the no_grad block guarding it was empty.
                with torch.no_grad():
                    simi_t = model.test_query(this_e1_real, this_e1_img,
                                              this_r_real, this_r_img, None,
                                              None, "tail", cut_mask,
                                              cut_nodes_real, cut_nodes_img)
                    # Pass the imaginary part of e2 (the real part used to be
                    # passed twice, leaving this_e2_img unused).
                    simi_h = model.test_query(None, None, this_r_real,
                                              this_r_img, this_e2_real,
                                              this_e2_img, "head", cut_mask,
                                              cut_nodes_real, cut_nodes_img)

                for suffix, rank in (
                        ('_t',
                         _plt_filtered_rank(simi_t, this_correct_mentions_e2,
                                            all_correct_mentions_e2)),
                        ('_h',
                         _plt_filtered_rank(simi_h, this_correct_mentions_e1,
                                            all_correct_mentions_e1))):
                    metrics['mr' + suffix] += rank
                    metrics['mrr' + suffix] += 1.0 / rank
                    metrics['hits1' + suffix] += rank.le(1).float()
                    metrics['hits10' + suffix] += rank.le(10).float()
                    metrics['hits50' + suffix] += rank.le(50).float()

            # Keep the head/tail averages current for the next progress print.
            for name in metric_names:
                metrics[name] = (metrics[name + '_h'] +
                                 metrics[name + '_t']) / 2

        for key in metrics:
            metrics[key] = metrics[key] / len(test_kb.triples)
        print(metrics)
def main(args):
    """Sample test triples that the evidence-injected model ranks first.

    Scores every test triple with a ``complexLSTM`` checkpoint against all
    entity mentions (head or tail prediction, chosen by
    ``args.head_or_tail``), and collects the triples with filtered rank <= 1
    that are *not* in the hard-coded set of indices the baseline already got
    right. Prints how many of the baseline indices are still rank 1, then up
    to ``args.sample`` randomly chosen collected triples together with their
    gold answers and the relations read from ``args.evidence_file``.

    Attributes read from ``args``: ``evidence_file``, ``n_times``,
    ``data_dir``, ``seed``, ``embedding_dim``, ``initial_token_embedding``,
    ``resume``, ``max_seq_length``, ``eval_batch_size``, ``head_or_tail`` and
    ``sample``.
    """
    hits_1_triple = []
    hits_1_correct_answers = []
    hits_1_model_top10 = []
    hits_1_evidence = []
    # Test-set positions that the baseline model already ranks first (tail).
    baseline_tail_hits1_indices = {
        36, 91, 95, 101, 119, 158, 282, 397, 638, 728, 740, 763, 914, 959,
        972, 992, 1184, 1478, 1669, 1686, 1732, 1795, 1796, 1822, 1826, 1845,
        1924, 1939, 1943, 2055, 2178, 2317, 2319, 2325, 2482, 2513, 2589,
        2627, 2674, 2736, 2862, 2985, 3049, 3311, 3327, 3491, 3660, 3728,
        3817, 3818, 4111, 4263, 4387, 4437, 4438, 4452, 4525, 4591, 4670,
        4856, 5114, 5159, 5318, 5587, 5851, 5857, 5893, 5925, 5942, 5990,
        6056, 6079, 6119, 6172, 6195, 6211, 6228, 6262, 6267, 6460, 6491,
        6509, 6584, 6676, 6699, 6862, 6982, 7057, 7078, 7084, 7221, 7597,
        7733, 7837, 8045, 8278, 8326, 8380, 8433, 8453, 8479, 8534, 8540,
        8742, 8813, 8860, 8906, 8930, 9234, 9333, 9500, 9535, 9589, 9663,
        9803, 9809, 9866, 9999
    }
    baseline_correct = 0

    # One row per test triple holding its args.n_times evidence relations.
    injected_rels = kb(args.evidence_file, em_map=None,
                       rm_map=None).triples[:, 1].reshape(-1, args.n_times)

    mapped_dir = os.path.join(args.data_dir, "mapped_to_ids")

    # read token maps
    etokens, etoken_map = utils.get_tokens_map(
        os.path.join(mapped_dir, "entity_token_id_map.txt"))
    rtokens, rtoken_map = utils.get_tokens_map(
        os.path.join(mapped_dir, "relation_token_id_map.txt"))
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(mapped_dir, "entity_id_map.txt"))
    _, rm_map = utils.read_mentions(
        os.path.join(mapped_dir, "relation_id_map.txt"))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # +1 for unk token
    model = complexLSTM(len(etoken_map) + 1,
                        len(rtoken_map) + 1,
                        args.embedding_dim,
                        initial_token_embedding=args.initial_token_embedding,
                        entity_tokens=etokens,
                        relation_tokens=rtokens,
                        lstm_dropout=0)
    if args.resume:
        print("Resuming from:", args.resume)
        checkpoint = torch.load(args.resume,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    # get embeddings for all entity mentions
    entity_mentions_tensor, entity_mentions_lengths = convert_string_to_indices(
        entity_mentions,
        etoken_map,
        maxlen=args.max_seq_length,
        use_tqdm=True)
    entity_mentions_tensor = entity_mentions_tensor.cuda()
    entity_mentions_lengths = entity_mentions_lengths.cuda()

    ementions_real_lis = []
    ementions_img_lis = []
    split = 100  # cant fit all in gpu together. hence split
    # Guard against a zero step when there are fewer than `split` mentions
    # (range() raises ValueError for step 0).
    chunk_size = max(1, len(entity_mentions_tensor) // split)
    with torch.no_grad():
        for start in tqdm(
                range(0, len(entity_mentions_tensor), chunk_size)):
            data = entity_mentions_tensor[start:start + chunk_size, :]
            data_lengths = entity_mentions_lengths[start:start + chunk_size]
            ementions_real_lstm, ementions_img_lstm = model.get_mention_embedding(
                data, 0, data_lengths)
            ementions_real_lis.append(ementions_real_lstm.cpu())
            ementions_img_lis.append(ementions_img_lstm.cpu())
    # Release the GPU copies before the full embedding matrices are uploaded.
    del entity_mentions_tensor, ementions_real_lstm, ementions_img_lstm
    torch.cuda.empty_cache()
    ementions_real = torch.cat(ementions_real_lis).cuda()
    ementions_img = torch.cat(ementions_img_lis).cuda()

    if "olpbench" in args.data_dir:
        test_kb = kb(os.path.join(args.data_dir, "test_data.txt"),
                     em_map=em_map,
                     rm_map=rm_map)
    else:
        test_kb = kb(os.path.join(args.data_dir, "test.txt"),
                     em_map=em_map,
                     rm_map=rm_map)

    print("Loading all_known pickled data...(takes times since large)")
    with open(os.path.join(args.data_dir, "all_knowns_thorough_linked.pkl"),
              "rb") as all_knowns_file:
        all_known_e2, all_known_e1 = pickle.load(all_knowns_file)

    test_e1_tokens_tensor, test_e1_tokens_lengths = convert_string_to_indices(
        test_kb.triples[:, 0], etoken_map, maxlen=args.max_seq_length)
    test_r_tokens_tensor, test_r_tokens_lengths = convert_string_to_indices(
        test_kb.triples[:, 1], rtoken_map, maxlen=args.max_seq_length)
    test_e2_tokens_tensor, test_e2_tokens_lengths = convert_string_to_indices(
        test_kb.triples[:, 2], etoken_map, maxlen=args.max_seq_length)

    # Position of every example in test_kb, used to fetch alternative answers
    # while evaluating. Integer dtype: float32 cannot hold every index once
    # there are more than 2**24 examples.
    indices = torch.arange(len(test_kb.triples))
    test_data = TensorDataset(indices, test_e1_tokens_tensor,
                              test_r_tokens_tensor, test_e2_tokens_tensor,
                              test_e1_tokens_lengths, test_r_tokens_lengths,
                              test_e2_tokens_lengths)
    test_sampler = SequentialSampler(test_data)
    test_dataloader = DataLoader(test_data,
                                 sampler=test_sampler,
                                 batch_size=args.eval_batch_size)

    split_dim_for_eval = 1
    if args.embedding_dim >= 512 and "olpbench" in args.data_dir:
        split_dim_for_eval = 4

    predict_tail = args.head_or_tail == "tail"

    for index, test_e1_tokens, test_r_tokens, test_e2_tokens, test_e1_lengths, test_r_lengths, test_e2_lengths in tqdm(
            test_dataloader, desc="Test dataloader"):
        test_e1_tokens = test_e1_tokens.to(device)
        test_e1_lengths = test_e1_lengths.to(device)
        test_r_tokens = test_r_tokens.to(device)
        test_r_lengths = test_r_lengths.to(device)
        test_e2_tokens = test_e2_tokens.to(device)
        test_e2_lengths = test_e2_lengths.to(device)
        with torch.no_grad():
            e1_real_lstm, e1_img_lstm = model.get_mention_embedding(
                test_e1_tokens, 0, test_e1_lengths)
            r_real_lstm, r_img_lstm = model.get_mention_embedding(
                test_r_tokens, 1, test_r_lengths)
            e2_real_lstm, e2_img_lstm = model.get_mention_embedding(
                test_e2_tokens, 0, test_e2_lengths)

        for count in tqdm(range(index.shape[0]), desc="Evaluating"):
            this_r_real = r_real_lstm[count].unsqueeze(0)
            this_r_img = r_img_lstm[count].unsqueeze(0)

            triple_idx = int(index[count].item())
            triple = test_kb.triples[triple_idx]

            # Pick the query side, its gold answers and the known answers
            # used for filtered ranking.
            if predict_tail:
                simi = model.complex_score_e1_r_with_all_ementions(
                    e1_real_lstm[count].unsqueeze(0),
                    e1_img_lstm[count].unsqueeze(0),
                    this_r_real,
                    this_r_img,
                    ementions_real,
                    ementions_img,
                    split=split_dim_for_eval).squeeze(0)
                gold_mentions = test_kb.e2_all_answers[triple_idx]
                known_mentions = all_known_e2.get(
                    (em_map[triple[0]], rm_map[triple[1]]), [])
            else:
                simi = model.complex_score_e2_r_with_all_ementions(
                    e2_real_lstm[count].unsqueeze(0),
                    e2_img_lstm[count].unsqueeze(0),
                    this_r_real,
                    this_r_img,
                    ementions_real,
                    ementions_img,
                    split=split_dim_for_eval).squeeze(0)
                gold_mentions = test_kb.e1_all_answers[triple_idx]
                known_mentions = all_known_e1.get(
                    (em_map[triple[2]], rm_map[triple[1]]), [])

            best_score = simi[gold_mentions].max()
            simi[known_mentions] = -20000000  # MOST NEGATIVE VALUE
            greatereq = simi.ge(best_score).float()
            equal = simi.eq(best_score).float()
            rank = greatereq.sum() + 1 + equal.sum() / 2.0

            # Triples the baseline already solves are only counted, never
            # sampled.
            if triple_idx in baseline_tail_hits1_indices:
                if rank <= 1:
                    baseline_correct += 1
                continue

            if rank <= 1:  # hits1
                hits_1_triple.append([triple[0], triple[1], triple[2]])
                hits_1_evidence.append(injected_rels[triple_idx].tolist())
                hits_1_correct_answers.append(
                    [entity_mentions[x] for x in gold_mentions])
                hits_1_model_top10.append([])

    indices = list(range(len(hits_1_triple)))
    random.shuffle(indices)
    indices = indices[:args.sample]
    print(baseline_correct)
    for ind in indices:
        print(ind, "|", hits_1_triple[ind], "|", hits_1_correct_answers[ind],
              "|", hits_1_model_top10[ind], "|", hits_1_evidence[ind])
help="Whether to run eval on the dev set.") parser.add_argument("--eval_batch_size", default=512, type=int, help="Total batch size for eval.") args = parser.parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # read token maps etokens, etoken_map = utils.get_tokens_map( os.path.join(args.data_dir, "mapped_to_ids", "entity_token_id_map.txt")) rtokens, rtoken_map = utils.get_tokens_map( os.path.join(args.data_dir, "mapped_to_ids", "relation_token_id_map.txt")) entity_mentions, em_map = utils.read_mentions( os.path.join(args.data_dir, "mapped_to_ids", "entity_id_map.txt")) relation_mentions, rm_map = utils.read_mentions( os.path.join(args.data_dir, "mapped_to_ids", "relation_id_map.txt")) if args.model == "complex": model = complexLSTM(len(etoken_map) + 1, len(rtoken_map) + 1, args.embedding_dim, initial_token_embedding=None, entity_tokens=etokens, relation_tokens=rtokens, lstm_dropout=args.lstm_dropout) elif args.model == "rotate": model = rotatELSTM(len(etoken_map) + 1, len(rtoken_map) + 1, args.embedding_dim,
def main(args): # read token maps etokens, etoken_map = utils.get_tokens_map( os.path.join(args.data_dir, "mapped_to_ids", "entity_token_id_map.txt")) rtokens, rtoken_map = utils.get_tokens_map( os.path.join(args.data_dir, "mapped_to_ids", "relation_token_id_map.txt")) entity_mentions, em_map = utils.read_mentions( os.path.join(args.data_dir, "mapped_to_ids", "entity_id_map.txt")) _, rm_map = utils.read_mentions( os.path.join(args.data_dir, "mapped_to_ids", "relation_id_map.txt")) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if not args.do_train and not args.do_eval: raise ValueError( "At least one of `do_train` or `do_eval` must be True.") #train code (+1 for unk token) if args.model == "complex": if args.separate_lstms: model = complexLSTM_2( len(etoken_map) + 1, len(rtoken_map) + 1, args.embedding_dim, initial_token_embedding=args.initial_token_embedding, entity_tokens=etokens, relation_tokens=rtokens, lstm_dropout=args.lstm_dropout) else: model = complexLSTM( len(etoken_map) + 1, len(rtoken_map) + 1, args.embedding_dim, initial_token_embedding=args.initial_token_embedding, entity_tokens=etokens, relation_tokens=rtokens, lstm_dropout=args.lstm_dropout) elif args.model == "rotate": model = rotatELSTM( len(etoken_map) + 1, len(rtoken_map) + 1, args.embedding_dim, initial_token_embedding=args.initial_token_embedding, entity_tokens=etokens, relation_tokens=rtokens, gamma=args.gamma_rotate, lstm_dropout=args.lstm_dropout) if (args.do_train): optimizer = torch.optim.Adagrad(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) if (args.resume): print("Resuming from:", args.resume) checkpoint = torch.load(args.resume) model.load_state_dict(checkpoint['state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) #Load other things too if required model.train() if "olpbench" in args.data_dir: train_kb = kb(os.path.join( args.data_dir, 
"train_data_{}.txt".format(args.train_data_type)), em_map=em_map, rm_map=rm_map) # train_kb = kb(os.path.join(args.data_dir,"train_data_thorough_r_sorted.txt"), em_map = em_map, rm_map = rm_map) # train_kb = kb(os.path.join(args.data_dir,"test_data.txt"), em_map = em_map, rm_map = rm_map) else: train_kb = kb(os.path.join(args.data_dir, "train.txt"), em_map=em_map, rm_map=rm_map) train_data = Dataset(train_kb.triples) train_sampler = RandomSampler(train_data, replacement=False) #train_sampler = SequentialSampler(train_data) train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) # crossEntropyLoss = torch.nn.CrossEntropyLoss(reduction='mean') BCEloss = torch.nn.BCEWithLogitsLoss(reduction='sum') for epoch in tqdm(range(0, args.num_train_epochs), desc="epoch"): iteration = 0 for train_e1_batch, train_r_batch, train_e2_batch in tqdm( train_dataloader, desc="Train dataloader"): # skip this batch if (random.random() < args.skip_train_prob): continue batch_size = len(train_e1_batch) train_e1_mention_tensor, train_e1_lengths = convert_string_to_indices( train_e1_batch, etoken_map, maxlen=args.max_seq_length) train_r_mention_tensor, train_r_lengths = convert_string_to_indices( train_r_batch, rtoken_map, maxlen=args.max_seq_length) train_e2_mention_tensor, train_e2_lengths = convert_string_to_indices( train_e2_batch, etoken_map, maxlen=args.max_seq_length) train_e1_mention_tensor, train_e1_lengths = train_e1_mention_tensor.cuda( ), train_e1_lengths.cuda() train_r_mention_tensor, train_r_lengths = train_r_mention_tensor.cuda( ), train_r_lengths.cuda() train_e2_mention_tensor, train_e2_lengths = train_e2_mention_tensor.cuda( ), train_e2_lengths.cuda() e1_real_lstm, e1_img_lstm = model.get_mention_embedding( train_e1_mention_tensor, 0, train_e1_lengths) r_real_lstm, r_img_lstm = model.get_mention_embedding( train_r_mention_tensor, 1, train_r_lengths) e2_real_lstm, e2_img_lstm = model.get_mention_embedding( train_e2_mention_tensor, 0, 
train_e2_lengths) #tail simi_t = model.complex_score_e1_r_with_all_ementions( e1_real_lstm, e1_img_lstm, r_real_lstm, r_img_lstm, e2_real_lstm, e2_img_lstm) #head simi_h = model.complex_score_e2_r_with_all_ementions( e2_real_lstm, e2_img_lstm, r_real_lstm, r_img_lstm, e1_real_lstm, e1_img_lstm) # change the loss suitably target = torch.eye(batch_size).cuda() # import pdb # pdb.set_trace() loss_t = BCEloss(simi_t.view(-1), target.view(-1)) loss_h = BCEloss(simi_h.view(-1), target.view(-1)) loss = (loss_h + loss_t) / 2 loss /= target.size(0) * target.size(1) # Do the routine optimizer.zero_grad() loss.backward() #gradient clip? optimizer.step() if (iteration % args.print_loss_every == 0): print("Current loss(avg, tail, head):", loss.item(), loss_t.item(), loss_h.item()) iteration += 1 if (epoch % args.save_model_every == 0 and epoch != 0): utils.save_checkpoint( { 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict() }, args.output_dir + "/checkpoint_epoch_{}".format(epoch + 1)) if args.do_eval: #eval code metrics = {} metrics['mr'] = 0 metrics['mrr'] = 0 metrics['hits1'] = 0 metrics['hits10'] = 0 metrics['hits50'] = 0 metrics['mr_t'] = 0 metrics['mrr_t'] = 0 metrics['hits1_t'] = 0 metrics['hits10_t'] = 0 metrics['hits50_t'] = 0 metrics['mr_h'] = 0 metrics['mrr_h'] = 0 metrics['hits1_h'] = 0 metrics['hits10_h'] = 0 metrics['hits50_h'] = 0 if (args.resume and not args.do_train): print("Resuming from:", args.resume) checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage) model.load_state_dict(checkpoint['state_dict']) model.eval() # get embeddings for all entity mentions entity_mentions_tensor, entity_mentions_lengths = convert_string_to_indices( entity_mentions, etoken_map, maxlen=args.max_seq_length, use_tqdm=True) entity_mentions_tensor = entity_mentions_tensor.cuda() entity_mentions_lengths = entity_mentions_lengths.cuda() ementions_real_lis = [] ementions_img_lis = [] split = 100 #cant fit all in gpu together. 
hence split with torch.no_grad(): for i in tqdm( range(0, len(entity_mentions_tensor), len(entity_mentions_tensor) // split)): data = entity_mentions_tensor[i:i + len(entity_mentions_tensor) // split, :] data_lengths = entity_mentions_lengths[ i:i + len(entity_mentions_tensor) // split] ementions_real_lstm, ementions_img_lstm = model.get_mention_embedding( data, 0, data_lengths) # a = model.Et_im(entity_mentions_tensor[i:i+len(entity_mentions_tensor)//split,:]) # b = model.Et_re(entity_mentions_tensor[i:i+len(entity_mentions_tensor)//split,:]) # a_lstm,_ = model.lstm(a) # a_lstm = a_lstm[:,-1,:] # b_lstm,_ = model.lstm(b) # b_lstm = b_lstm[:,-1,:] ementions_real_lis.append(ementions_real_lstm.cpu()) ementions_img_lis.append(ementions_img_lstm.cpu()) del entity_mentions_tensor, ementions_real_lstm, ementions_img_lstm torch.cuda.empty_cache() ementions_real = torch.cat(ementions_real_lis).cuda() ementions_img = torch.cat(ementions_img_lis).cuda() ######################################################################## if "olpbench" in args.data_dir: # test_kb = kb(os.path.join(args.data_dir,"test_data_sophis.txt"), em_map = em_map, rm_map = rm_map) test_kb = kb(os.path.join(args.data_dir, "test_data.txt"), em_map=em_map, rm_map=rm_map) else: test_kb = kb(os.path.join(args.data_dir, "test.txt"), em_map=em_map, rm_map=rm_map) print("Loading all_known pickled data...(takes times since large)") all_known_e2 = {} all_known_e1 = {} all_known_e2, all_known_e1 = pickle.load( open( os.path.join( args.data_dir, "all_knowns_{}_linked.pkl".format(args.train_data_type)), "rb")) test_e1_tokens_tensor, test_e1_tokens_lengths = convert_string_to_indices( test_kb.triples[:, 0], etoken_map, maxlen=args.max_seq_length) test_r_tokens_tensor, test_r_tokens_lengths = convert_string_to_indices( test_kb.triples[:, 1], rtoken_map, maxlen=args.max_seq_length) test_e2_tokens_tensor, test_e2_tokens_lengths = convert_string_to_indices( test_kb.triples[:, 2], etoken_map, 
maxlen=args.max_seq_length) # e2_tensor = convert_string_to_indices(test_kb.triples[:,2], etoken_map) indices = torch.Tensor( range(len(test_kb.triples)) ) #indices would be used to fetch alternative answers while evaluating test_data = TensorDataset(indices, test_e1_tokens_tensor, test_r_tokens_tensor, test_e2_tokens_tensor, test_e1_tokens_lengths, test_r_tokens_lengths, test_e2_tokens_lengths) test_sampler = SequentialSampler(test_data) test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=args.eval_batch_size) split_dim_for_eval = 1 if (args.embedding_dim >= 256 and "olpbench" in args.data_dir and "rotat" in args.model): split_dim_for_eval = 4 if (args.embedding_dim >= 512 and "olpbench" in args.data_dir): split_dim_for_eval = 4 if (args.embedding_dim >= 512 and "olpbench" in args.data_dir and "rotat" in args.model): split_dim_for_eval = 6 split_dim_for_eval = 1 for index, test_e1_tokens, test_r_tokens, test_e2_tokens, test_e1_lengths, test_r_lengths, test_e2_lengths in tqdm( test_dataloader, desc="Test dataloader"): print(metrics) test_e1_tokens, test_e1_lengths = test_e1_tokens.to( device), test_e1_lengths.to(device) test_r_tokens, test_r_lengths = test_r_tokens.to( device), test_r_lengths.to(device) test_e2_tokens, test_e2_lengths = test_e2_tokens.to( device), test_e2_lengths.to(device) with torch.no_grad(): e1_real_lstm, e1_img_lstm = model.get_mention_embedding( test_e1_tokens, 0, test_e1_lengths) r_real_lstm, r_img_lstm = model.get_mention_embedding( test_r_tokens, 1, test_r_lengths) e2_real_lstm, e2_img_lstm = model.get_mention_embedding( test_e2_tokens, 0, test_e2_lengths) for count in tqdm(range(index.shape[0]), desc="Evaluating"): # breakpoint() this_e1_real = e1_real_lstm[count].unsqueeze(0) this_e1_img = e1_img_lstm[count].unsqueeze(0) this_r_real = r_real_lstm[count].unsqueeze(0) this_r_img = r_img_lstm[count].unsqueeze(0) this_e2_real = e2_real_lstm[count].unsqueeze(0) this_e2_img = e2_img_lstm[count].unsqueeze(0) # import pdb # 
pdb.set_trace() simi_t = model.complex_score_e1_r_with_all_ementions( this_e1_real, this_e1_img, this_r_real, this_r_img, ementions_real, ementions_img, split=split_dim_for_eval).squeeze(0) simi_h = model.complex_score_e2_r_with_all_ementions( this_e2_real, this_e2_img, this_r_real, this_r_img, ementions_real, ementions_img, split=split_dim_for_eval).squeeze(0) # get known answers for filtered ranking ind = index[count] this_correct_mentions_e2 = test_kb.e2_all_answers[int( ind.item())] this_correct_mentions_e1 = test_kb.e1_all_answers[int( ind.item())] all_correct_mentions_e2 = all_known_e2.get( (em_map[test_kb.triples[int(ind.item())][0]], rm_map[test_kb.triples[int(ind.item())][1]]), []) all_correct_mentions_e1 = all_known_e1.get( (em_map[test_kb.triples[int(ind.item())][2]], rm_map[test_kb.triples[int(ind.item())][1]]), []) # compute metrics best_score = simi_t[this_correct_mentions_e2].max() simi_t[ all_correct_mentions_e2] = -20000000 # MOST NEGATIVE VALUE greatereq = simi_t.ge(best_score).float() equal = simi_t.eq(best_score).float() rank = greatereq.sum() + 1 + equal.sum() / 2.0 metrics['mr_t'] += rank metrics['mrr_t'] += 1.0 / rank metrics['hits1_t'] += rank.le(1).float() metrics['hits10_t'] += rank.le(10).float() metrics['hits50_t'] += rank.le(50).float() best_score = simi_h[this_correct_mentions_e1].max() simi_h[ all_correct_mentions_e1] = -20000000 # MOST NEGATIVE VALUE greatereq = simi_h.ge(best_score).float() equal = simi_h.eq(best_score).float() rank = greatereq.sum() + 1 + equal.sum() / 2.0 metrics['mr_h'] += rank metrics['mrr_h'] += 1.0 / rank metrics['hits1_h'] += rank.le(1).float() metrics['hits10_h'] += rank.le(10).float() metrics['hits50_h'] += rank.le(50).float() metrics['mr'] = (metrics['mr_h'] + metrics['mr_t']) / 2 metrics['mrr'] = (metrics['mrr_h'] + metrics['mrr_t']) / 2 metrics['hits1'] = (metrics['hits1_h'] + metrics['hits1_t']) / 2 metrics['hits10'] = (metrics['hits10_h'] + metrics['hits10_t']) / 2 metrics['hits50'] = 
(metrics['hits50_h'] + metrics['hits50_t']) / 2 for key in metrics: metrics[key] = metrics[key] / len(test_kb.triples) print(metrics)
def _count_relation_frequencies(triples):
    """Count relation occurrences and per-relation tail occurrences.

    Returns a pair ``(freq_r, freq_r_e2)`` where ``freq_r[r]`` is the number
    of triples using relation ``r`` and ``freq_r_e2[r][e2]`` is the number of
    triples using relation ``r`` with tail ``e2``.
    """
    freq_r = {}
    freq_r_e2 = {}
    for triple in triples:
        r = triple[1].item()
        e2 = triple[2].item()
        freq_r[r] = freq_r.get(r, 0) + 1
        tail_counts = freq_r_e2.setdefault(r, {})
        tail_counts[e2] = tail_counts.get(e2, 0) + 1
    return freq_r, freq_r_e2


def _map_known_mentions(raw_field, em_map):
    """Split a ``|||``-joined mention field and keep the ids of known mentions."""
    return [em_map[mention] for mention in raw_field.split("|||")
            if mention in em_map]


def _print_sampled_cases(triples, gold_answers, predictions, sample,
                         freq_r, freq_r_e2):
    """Print up to ``sample`` randomly chosen cases with frequency statistics.

    Output columns: triple | gold answers | model predictions |
    freq(r) | freq(r, gold e2) | freq(r, top prediction) | max_e2 freq(r, e2).
    The top-prediction column is "N/A" when no predictions were recorded.
    """
    order = list(range(len(triples)))
    random.shuffle(order)
    for ind in order[:sample]:
        _, r, e2 = triples[ind]
        tail_counts = freq_r_e2.get(r, {})
        preds = predictions[ind]
        freq_of_r = freq_r.get(r, 0)
        freq_of_r_gold = tail_counts.get(e2, 0)
        # Guard against an empty prediction list instead of raising IndexError.
        freq_of_r_pred = tail_counts.get(preds[0], 0) if preds else "N/A"
        freq_of_r_maxe = max(tail_counts.values()) if tail_counts else 0
        print(triples[ind], "|", gold_answers[ind], "|", preds, "|",
              freq_of_r, "|", freq_of_r_gold, "|", freq_of_r_pred, "|",
              freq_of_r_maxe)


def main():
    """Analyse a top-k prediction file: report hits@1/10/50 and sample cases.

    Each line of the prediction file is tab-separated: e1, r, e2, the
    ``|||``-joined alternative e1 mentions, the ``|||``-joined alternative e2
    mentions, and a Python literal list of ``[entity_index, score]`` pairs.
    Known answers for the query are masked before ranking. After the
    aggregate metrics, up to ``sample`` hits@1 cases and up to ``sample``
    cases ranked worse than 50 are printed with training-set frequencies.
    """
    data_dir = "data/olpbench"
    head_or_tail = "tail"
    sample = 100
    masked_score = -2000000000  # sentinel used to push an entry out of the ranking

    train_kb = kb(os.path.join(data_dir, "train_data_thorough.txt"),
                  em_map=None,
                  rm_map=None)
    freq_r, freq_r_e2 = _count_relation_frequencies(train_kb.triples)

    pred_file = "helper_scripts/keshav_xt-preds/validation_data_linked_mention.txt.tail_thorough_n2_e50_lr0.1.stage1"
    _, rm_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "relation_id_map.txt"))
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "entity_id_map.txt"))

    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    print("Loading all_known pickled data...(takes times since large)")
    # NOTE: pickle executes arbitrary code on load; only use trusted files.
    with open(os.path.join(data_dir, "all_knowns_thorough_linked.pkl"),
              "rb") as known_file:
        all_known_e2, all_known_e1 = pickle.load(known_file)

    hits_1_triple = []
    hits_1_correct_answers = []
    hits_1_model_top10 = []
    nothits_50_triple = []
    nothits_50_correct_answers = []
    nothits_50_model_top10 = []
    ranks = []

    with open(pred_file) as pred_handle:
        lines = pred_handle.readlines()

    for line in tqdm(lines, desc="preds"):
        fields = line.strip().split("\t")
        e1, r, e2 = fields[0], fields[1], fields[2]
        this_correct_mentions_e2 = _map_known_mentions(fields[4], em_map)
        this_correct_mentions_e1 = _map_known_mentions(fields[3], em_map)
        all_correct_mentions_e2 = all_known_e2.get((em_map[e1], rm_map[r]), [])
        all_correct_mentions_e1 = all_known_e1.get((em_map[e2], rm_map[r]), [])

        indices_scores = torch.tensor(ast.literal_eval(fields[5]))
        topk_scores = indices_scores[:, 1]
        # Plain Python ints make the membership tests below independent of
        # tensor hashing/equality semantics.
        topk_indices = indices_scores[:, 0].long().tolist()

        if head_or_tail == "tail":
            this_gold = this_correct_mentions_e2
            all_gold = all_correct_mentions_e2
        else:
            this_gold = this_correct_mentions_e1
            all_gold = all_correct_mentions_e1
        this_gold_set = set(this_gold)

        # Best score among the gold answers present in the top-k list; each
        # gold entry is masked once it has been taken into account.
        best_score = masked_score
        for pos, entity_index in enumerate(topk_indices):
            if entity_index in this_gold_set:
                best_score = max(best_score, topk_scores[pos].item())
                topk_scores[pos] = masked_score
        # Filtered ranking: every other known answer is masked as well.
        for pos, entity_index in enumerate(topk_indices):
            if entity_index in all_gold:
                topk_scores[pos] = masked_score

        greatereq = topk_scores.ge(best_score).float()
        equal = topk_scores.eq(best_score).float()
        rank = (greatereq.sum() + 1 + equal.sum() / 2.0).item()

        if rank <= 1:
            hits_1_triple.append([e1, r, e2])
            hits_1_correct_answers.append(
                [entity_mentions[x] for x in this_gold])
            hits_1_model_top10.append([])
        elif rank > 50:
            nothits_50_triple.append([e1, r, e2])
            nothits_50_correct_answers.append(
                [entity_mentions[x] for x in this_gold])
            nothits_50_model_top10.append(
                [entity_mentions[x] for x in topk_indices])
        ranks.append(rank)

    # Avoid ZeroDivisionError when the prediction file is empty.
    total = max(len(lines), 1)
    result = {
        "hits1": sum(1 for rank in ranks if rank <= 1) / total,
        "hits10": sum(1 for rank in ranks if rank <= 10) / total,
        "hits50": sum(1 for rank in ranks if rank <= 50) / total,
    }
    print(result)

    # print format
    # triple, correct answers, model predictions, r: freq, r_gold: freq, r_prediction: freq, r_max-e2: freq
    _print_sampled_cases(hits_1_triple, hits_1_correct_answers,
                         hits_1_model_top10, sample, freq_r, freq_r_e2)
    print(
        "---------------------------------------------------------------------------------------------"
    )
    _print_sampled_cases(nothits_50_triple, nothits_50_correct_answers,
                         nothits_50_model_top10, sample, freq_r, freq_r_e2)
def _accumulate_filtered_rank_metrics(metrics, suffix, scores, this_correct,
                                      all_correct):
    """Add one query's filtered-rank statistics to ``metrics``.

    ``suffix`` selects the metric family ('_t' or '_h'). ``scores`` is
    modified in place: every index in ``all_correct`` is overwritten with a
    large negative value after the best score among ``this_correct`` has been
    read.
    """
    best_score = scores[this_correct].max()
    scores[all_correct] = -20000000  # MOST NEGATIVE VALUE
    greatereq = scores.ge(best_score).float()
    equal = scores.eq(best_score).float()
    rank = greatereq.sum() + 1 + equal.sum() / 2.0
    metrics['mr' + suffix] += rank
    metrics['mrr' + suffix] += 1.0 / rank
    metrics['hits1' + suffix] += rank.le(1).float()
    metrics['hits10' + suffix] += rank.le(10).float()
    metrics['hits50' + suffix] += rank.le(50).float()


def main(args):
    """Train and/or evaluate a mention-LSTM link-prediction model.

    Reads token and mention maps from ``args.data_dir``, builds the model
    selected by ``args.model``, optionally trains it with a summed BCE loss
    (``args.do_train``) and optionally evaluates filtered MR / MRR / hits@k
    for head and tail prediction on the test split (``args.do_eval``).

    Raises:
        ValueError: if neither ``do_train`` nor ``do_eval`` is set, or if
            ``args.model`` is not a supported model name.
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # read token maps
    mapped_dir = os.path.join(args.data_dir, "mapped_to_ids")
    etokens, etoken_map = utils.get_tokens_map(os.path.join(args.data_dir,"mapped_to_ids","entity_token_id_map.txt"))
    rtokens, rtoken_map = utils.get_tokens_map(os.path.join(args.data_dir,"mapped_to_ids","relation_token_id_map.txt"))
    entity_mentions, em_map = utils.read_mentions(os.path.join(args.data_dir,"mapped_to_ids","entity_id_map.txt"))
    relation_mentions, rm_map = utils.read_mentions(os.path.join(args.data_dir,"mapped_to_ids","relation_id_map.txt"))

    # create entity_token_indices and entity_lengths
    # [[max length indices for entity 0 ], [max length indices for entity 1], [max length indices for entity 2], ...]
    # [length of entity 0, length of entity 1, length of entity 2, ...]
    entity_token_indices, entity_lengths = utils.get_token_indices_from_mention_indices(
        entity_mentions, etoken_map, maxlen=args.max_seq_length, use_tqdm=True)
    relation_token_indices, relation_lengths = utils.get_token_indices_from_mention_indices(
        relation_mentions, rtoken_map, maxlen=args.max_seq_length, use_tqdm=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    # train code (+1 for unk token)
    if args.model == "complex":
        model_class = complexLSTM_2 if args.separate_lstms else complexLSTM
        model = model_class(len(etoken_map) + 1, len(rtoken_map) + 1, args.embedding_dim,
                            initial_token_embedding=args.initial_token_embedding,
                            entity_tokens=etokens, relation_tokens=rtokens,
                            lstm_dropout=args.lstm_dropout)
    elif args.model == "rotate":
        model = rotatELSTM(len(etoken_map) + 1, len(rtoken_map) + 1, args.embedding_dim,
                           initial_token_embedding=args.initial_token_embedding,
                           entity_tokens=etokens, relation_tokens=rtokens,
                           gamma=args.gamma_rotate, lstm_dropout=args.lstm_dropout)
    else:
        # Fail early with a clear message instead of a NameError on `model`.
        raise ValueError("Unsupported model: {!r}".format(args.model))

    if args.do_train:
        data_config = {'input_file': 'train_data_thorough.txt',
                       'batch_size': args.train_batch_size,
                       'use_batch_shared_entities': True,
                       'min_size_batch_labels': args.train_batch_size,
                       'max_size_prefix_label': 64,
                       'device': 0}
        expt_settings = {'loss': 'bce',
                         'replace_entities_by_tokens': True,
                         'replace_relations_by_tokens': True,
                         'max_lengths_tuple': [10, 10]}
        train_data = OneToNMentionRelationDataset(dataset_dir=mapped_dir,
                                                  is_training_data=True,
                                                  **data_config,
                                                  **expt_settings)
        train_data.create_data_tensors(
            dataset_dir=mapped_dir,
            train_input_file='train_data_thorough.txt',
            valid_input_file='validation_data_linked.txt',
            test_input_file='test_data.txt',
        )
        train_loader = train_data.get_loader(
            shuffle=True,
            num_workers=8,
            drop_last=True,
        )

        optimizer = torch.optim.Adagrad(model.parameters(), lr=args.learning_rate,
                                        weight_decay=args.weight_decay)
        if args.resume:
            print("Resuming from:", args.resume)
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])  # Load other things too if required
        model.train()

        BCEloss = torch.nn.BCEWithLogitsLoss(reduction='sum')
        for epoch in tqdm(range(0, args.num_train_epochs), desc="epoch"):
            iteration = 0
            for batch in tqdm(train_loader, desc="Train dataloader"):
                inputs, \
                normalizer_loss, \
                normalizer_metric, \
                labels, \
                label_ids, \
                filter_mask, \
                batch_shared_entities = train_data.input_and_labels_to_device(
                    batch,
                    training=True,
                    device="cpu"
                )
                labels = labels.cuda()
                all_outputs = []
                for mode, model_inputs in zip(["head", "tail"], inputs):
                    if model_inputs is None:
                        continue
                    # subtract two from author's indices because our map is 2 less
                    if mode == "head":
                        batch_e2_indices = model_inputs[1] - 2
                        batch_r_indices = model_inputs[0] - 2
                        batch_e1_indices = batch_shared_entities - 2
                    else:
                        batch_e1_indices = model_inputs[0] - 2
                        batch_r_indices = model_inputs[1] - 2
                        batch_e2_indices = batch_shared_entities - 2

                    train_e1_mention_tensor, train_e1_lengths = convert_mention_to_token_indices(
                        batch_e1_indices.squeeze(1), entity_token_indices, entity_lengths)
                    train_r_mention_tensor, train_r_lengths = convert_mention_to_token_indices(
                        batch_r_indices.squeeze(1), relation_token_indices, relation_lengths)
                    train_e2_mention_tensor, train_e2_lengths = convert_mention_to_token_indices(
                        batch_e2_indices.squeeze(1), entity_token_indices, entity_lengths)

                    train_e1_mention_tensor, train_e1_lengths = train_e1_mention_tensor.cuda(), train_e1_lengths.cuda()
                    train_r_mention_tensor, train_r_lengths = train_r_mention_tensor.cuda(), train_r_lengths.cuda()
                    train_e2_mention_tensor, train_e2_lengths = train_e2_mention_tensor.cuda(), train_e2_lengths.cuda()

                    # Second argument selects the encoder: 0 for entities, 1 for relations.
                    e1_real_lstm, e1_img_lstm = model.get_mention_embedding(train_e1_mention_tensor, 0, train_e1_lengths)
                    r_real_lstm, r_img_lstm = model.get_mention_embedding(train_r_mention_tensor, 1, train_r_lengths)
                    e2_real_lstm, e2_img_lstm = model.get_mention_embedding(train_e2_mention_tensor, 0, train_e2_lengths)

                    if mode == "head":
                        output = model.complex_score_e2_r_with_all_ementions(
                            e2_real_lstm, e2_img_lstm, r_real_lstm, r_img_lstm, e1_real_lstm, e1_img_lstm)
                    else:
                        output = model.complex_score_e1_r_with_all_ementions(
                            e1_real_lstm, e1_img_lstm, r_real_lstm, r_img_lstm, e2_real_lstm, e2_img_lstm)
                    all_outputs.append(output)

                all_outputs = torch.cat(all_outputs)
                loss = BCEloss(all_outputs.view(-1), labels.view(-1))
                loss /= normalizer_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if iteration % args.print_loss_every == 0:
                    print("Current loss:", loss.item())
                iteration += 1

            if epoch % args.save_model_every == 0:
                utils.save_checkpoint({
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, args.output_dir + "/checkpoint_epoch_{}".format(epoch + 1))

    if args.do_eval:
        # eval code
        # Key order (overall, tail, head) is kept because the dict is printed.
        metrics = {name + suffix: 0
                   for suffix in ('', '_t', '_h')
                   for name in ('mr', 'mrr', 'hits1', 'hits10', 'hits50')}

        if args.resume and not args.do_train:
            print("Resuming from:", args.resume)
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)
            model.load_state_dict(checkpoint['state_dict'])
        model.eval()

        # get embeddings for all entity mentions
        entity_mentions_tensor, entity_mentions_lengths = convert_string_to_indices(
            entity_mentions, etoken_map, maxlen=args.max_seq_length, use_tqdm=True)
        entity_mentions_tensor = entity_mentions_tensor.cuda()
        entity_mentions_lengths = entity_mentions_lengths.cuda()

        ementions_real_lis = []
        ementions_img_lis = []
        split = 100  # cant fit all in gpu together. hence split
        # At least one row per chunk: a zero step would make range() raise
        # ValueError when there are fewer than `split` mentions.
        chunk_size = max(1, len(entity_mentions_tensor) // split)
        with torch.no_grad():
            for i in tqdm(range(0, len(entity_mentions_tensor), chunk_size)):
                data = entity_mentions_tensor[i:i + chunk_size, :]
                data_lengths = entity_mentions_lengths[i:i + chunk_size]
                ementions_real_lstm, ementions_img_lstm = model.get_mention_embedding(data, 0, data_lengths)
                ementions_real_lis.append(ementions_real_lstm.cpu())
                ementions_img_lis.append(ementions_img_lstm.cpu())
        del entity_mentions_tensor, ementions_real_lstm, ementions_img_lstm
        torch.cuda.empty_cache()
        ementions_real = torch.cat(ementions_real_lis).cuda()
        ementions_img = torch.cat(ementions_img_lis).cuda()
        ########################################################################

        if "olpbench" in args.data_dir:
            test_kb = kb(os.path.join(args.data_dir, "test_data.txt"), em_map=em_map, rm_map=rm_map)
        else:
            test_kb = kb(os.path.join(args.data_dir, "test.txt"), em_map=em_map, rm_map=rm_map)

        print("Loading all_known pickled data...(takes times since large)")
        # NOTE: pickle executes arbitrary code on load; only use trusted files.
        known_path = os.path.join(args.data_dir, "all_knowns_{}_linked.pkl".format(args.train_data_type))
        with open(known_path, "rb") as known_file:
            all_known_e2, all_known_e1 = pickle.load(known_file)

        test_e1_tokens_tensor, test_e1_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 0], etoken_map, maxlen=args.max_seq_length)
        test_r_tokens_tensor, test_r_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 1], rtoken_map, maxlen=args.max_seq_length)
        test_e2_tokens_tensor, test_e2_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 2], etoken_map, maxlen=args.max_seq_length)

        # indices would be used to fetch alternative answers while evaluating.
        # Integer dtype keeps them exact (float32 loses exactness above 2**24).
        indices = torch.arange(len(test_kb.triples))
        test_data = TensorDataset(indices,
                                  test_e1_tokens_tensor, test_r_tokens_tensor, test_e2_tokens_tensor,
                                  test_e1_tokens_lengths, test_r_tokens_lengths, test_e2_tokens_lengths)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=args.eval_batch_size)

        split_dim_for_eval = 1
        if args.embedding_dim >= 256 and "olpbench" in args.data_dir and "rotat" in args.model:
            split_dim_for_eval = 4
        if args.embedding_dim >= 512 and "olpbench" in args.data_dir:
            split_dim_for_eval = 4
        if args.embedding_dim >= 512 and "olpbench" in args.data_dir and "rotat" in args.model:
            split_dim_for_eval = 6
        # NOTE(review): the heuristic above is unconditionally overridden here,
        # exactly as in the original code; remove this line to re-enable it.
        split_dim_for_eval = 1

        for index, test_e1_tokens, test_r_tokens, test_e2_tokens, test_e1_lengths, test_r_lengths, test_e2_lengths in tqdm(test_dataloader, desc="Test dataloader"):
            print(metrics)
            test_e1_tokens, test_e1_lengths = test_e1_tokens.to(device), test_e1_lengths.to(device)
            test_r_tokens, test_r_lengths = test_r_tokens.to(device), test_r_lengths.to(device)
            test_e2_tokens, test_e2_lengths = test_e2_tokens.to(device), test_e2_lengths.to(device)
            with torch.no_grad():
                e1_real_lstm, e1_img_lstm = model.get_mention_embedding(test_e1_tokens, 0, test_e1_lengths)
                r_real_lstm, r_img_lstm = model.get_mention_embedding(test_r_tokens, 1, test_r_lengths)
                e2_real_lstm, e2_img_lstm = model.get_mention_embedding(test_e2_tokens, 0, test_e2_lengths)

            for count in tqdm(range(index.shape[0]), desc="Evaluating"):
                this_e1_real = e1_real_lstm[count].unsqueeze(0)
                this_e1_img = e1_img_lstm[count].unsqueeze(0)
                this_r_real = r_real_lstm[count].unsqueeze(0)
                this_r_img = r_img_lstm[count].unsqueeze(0)
                this_e2_real = e2_real_lstm[count].unsqueeze(0)
                this_e2_img = e2_img_lstm[count].unsqueeze(0)

                simi_t = model.complex_score_e1_r_with_all_ementions(
                    this_e1_real, this_e1_img, this_r_real, this_r_img,
                    ementions_real, ementions_img, split=split_dim_for_eval).squeeze(0)
                simi_h = model.complex_score_e2_r_with_all_ementions(
                    this_e2_real, this_e2_img, this_r_real, this_r_img,
                    ementions_real, ementions_img, split=split_dim_for_eval).squeeze(0)

                # get known answers for filtered ranking
                triple_index = int(index[count].item())
                triple = test_kb.triples[triple_index]
                this_correct_mentions_e2 = test_kb.e2_all_answers[triple_index]
                this_correct_mentions_e1 = test_kb.e1_all_answers[triple_index]
                all_correct_mentions_e2 = all_known_e2.get((em_map[triple[0]], rm_map[triple[1]]), [])
                all_correct_mentions_e1 = all_known_e1.get((em_map[triple[2]], rm_map[triple[1]]), [])

                # compute metrics
                _accumulate_filtered_rank_metrics(
                    metrics, '_t', simi_t, this_correct_mentions_e2, all_correct_mentions_e2)
                _accumulate_filtered_rank_metrics(
                    metrics, '_h', simi_h, this_correct_mentions_e1, all_correct_mentions_e1)

                metrics['mr'] = (metrics['mr_h'] + metrics['mr_t']) / 2
                metrics['mrr'] = (metrics['mrr_h'] + metrics['mrr_t']) / 2
                metrics['hits1'] = (metrics['hits1_h'] + metrics['hits1_t']) / 2
                metrics['hits10'] = (metrics['hits10_h'] + metrics['hits10_t']) / 2
                metrics['hits50'] = (metrics['hits50_h'] + metrics['hits50_t']) / 2

        for key in metrics:
            metrics[key] = metrics[key] / len(test_kb.triples)
        print(metrics)