# NOTE: imports shared by the scripts below. `utils`, `kb`, `Dataset`,
# `convert_string_to_indices`, `convert_mention_to_token_indices`,
# `OneToNMentionRelationDataset` and the complexLSTM / complexLSTM_2 /
# rotatELSTM models are project-local modules.
import os
import pickle
import random

import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from tqdm import tqdm


def main(args):
    # read token maps
    etokens, etoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids", "entity_token_id_map.txt"))
    rtokens, rtoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids", "relation_token_id_map.txt"))
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "entity_id_map.txt"))
    _, rm_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "relation_id_map.txt"))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    # build model (+1 for the unk token)
    if args.model == "complex":
        if args.separate_lstms:
            model = complexLSTM_2(
                len(etoken_map) + 1,
                len(rtoken_map) + 1,
                args.embedding_dim,
                initial_token_embedding=args.initial_token_embedding,
                entity_tokens=etokens,
                relation_tokens=rtokens,
                lstm_dropout=args.lstm_dropout)
        else:
            model = complexLSTM(
                len(etoken_map) + 1,
                len(rtoken_map) + 1,
                args.embedding_dim,
                initial_token_embedding=args.initial_token_embedding,
                entity_tokens=etokens,
                relation_tokens=rtokens,
                lstm_dropout=args.lstm_dropout)
    elif args.model == "rotate":
        model = rotatELSTM(
            len(etoken_map) + 1,
            len(rtoken_map) + 1,
            args.embedding_dim,
            initial_token_embedding=args.initial_token_embedding,
            entity_tokens=etokens,
            relation_tokens=rtokens,
            gamma=args.gamma_rotate,
            lstm_dropout=args.lstm_dropout)

    if args.do_train:
        optimizer = torch.optim.Adagrad(model.parameters(),
                                        lr=args.learning_rate,
                                        weight_decay=args.weight_decay)
        if args.resume:
            print("Resuming from:", args.resume)
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # load other things too if required

        model.train()
        if "olpbench" in args.data_dir:
            train_kb = kb(os.path.join(
                args.data_dir,
                "train_data_{}.txt".format(args.train_data_type)),
                          em_map=em_map,
                          rm_map=rm_map)
        else:
            train_kb = kb(os.path.join(args.data_dir, "train.txt"),
                          em_map=em_map,
                          rm_map=rm_map)
        train_data = Dataset(train_kb.triples)
        train_sampler = RandomSampler(train_data, replacement=False)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)
        BCEloss = torch.nn.BCEWithLogitsLoss(reduction='sum')

        for epoch in tqdm(range(0, args.num_train_epochs), desc="epoch"):
            iteration = 0
            for train_e1_batch, train_r_batch, train_e2_batch in tqdm(
                    train_dataloader, desc="Train dataloader"):
                # randomly skip this batch
                if random.random() < args.skip_train_prob:
                    continue
                batch_size = len(train_e1_batch)
                train_e1_mention_tensor, train_e1_lengths = convert_string_to_indices(
                    train_e1_batch, etoken_map, maxlen=args.max_seq_length)
                train_r_mention_tensor, train_r_lengths = convert_string_to_indices(
                    train_r_batch, rtoken_map, maxlen=args.max_seq_length)
                train_e2_mention_tensor, train_e2_lengths = convert_string_to_indices(
                    train_e2_batch, etoken_map, maxlen=args.max_seq_length)

                train_e1_mention_tensor, train_e1_lengths = train_e1_mention_tensor.cuda(), train_e1_lengths.cuda()
                train_r_mention_tensor, train_r_lengths = train_r_mention_tensor.cuda(), train_r_lengths.cuda()
                train_e2_mention_tensor, train_e2_lengths = train_e2_mention_tensor.cuda(), train_e2_lengths.cuda()

                e1_real_lstm, e1_img_lstm = model.get_mention_embedding(
                    train_e1_mention_tensor, 0, train_e1_lengths)
                r_real_lstm, r_img_lstm = model.get_mention_embedding(
                    train_r_mention_tensor, 1, train_r_lengths)
                e2_real_lstm, e2_img_lstm = model.get_mention_embedding(
                    train_e2_mention_tensor, 0, train_e2_lengths)

                # tail prediction: score each (e1, r) against every e2 in the batch
                simi_t = model.complex_score_e1_r_with_all_ementions(
                    e1_real_lstm, e1_img_lstm, r_real_lstm, r_img_lstm,
                    e2_real_lstm, e2_img_lstm)
                # head prediction: score each (e2, r) against every e1 in the batch
                simi_h = model.complex_score_e2_r_with_all_ementions(
                    e2_real_lstm, e2_img_lstm, r_real_lstm, r_img_lstm,
                    e1_real_lstm, e1_img_lstm)

                # the diagonal holds the true pairs; everything else is a negative
                target = torch.eye(batch_size).cuda()
                loss_t = BCEloss(simi_t.view(-1), target.view(-1))
                loss_h = BCEloss(simi_h.view(-1), target.view(-1))
                loss = (loss_h + loss_t) / 2
                loss /= target.size(0) * target.size(1)

                optimizer.zero_grad()
                loss.backward()
                # gradient clipping could be added here
                optimizer.step()

                if iteration % args.print_loss_every == 0:
                    print("Current loss (avg, tail, head):", loss.item(),
                          loss_t.item(), loss_h.item())
                iteration += 1

            if epoch % args.save_model_every == 0 and epoch != 0:
                utils.save_checkpoint(
                    {
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }, args.output_dir + "/checkpoint_epoch_{}".format(epoch + 1))

    if args.do_eval:
        # eval code
        metrics = {}
        for key in [
                'mr', 'mrr', 'hits1', 'hits10', 'hits50',
                'mr_t', 'mrr_t', 'hits1_t', 'hits10_t', 'hits50_t',
                'mr_h', 'mrr_h', 'hits1_h', 'hits10_h', 'hits50_h'
        ]:
            metrics[key] = 0

        if args.resume and not args.do_train:
            print("Resuming from:", args.resume)
            checkpoint = torch.load(args.resume,
                                    map_location=lambda storage, loc: storage)
            model.load_state_dict(checkpoint['state_dict'])

        model.eval()

        # get embeddings for all entity mentions
        entity_mentions_tensor, entity_mentions_lengths = convert_string_to_indices(
            entity_mentions,
            etoken_map,
            maxlen=args.max_seq_length,
            use_tqdm=True)
        entity_mentions_tensor = entity_mentions_tensor.cuda()
        entity_mentions_lengths = entity_mentions_lengths.cuda()

        ementions_real_lis = []
        ementions_img_lis = []
        split = 100  # can't fit everything on the GPU at once, hence the split
        with torch.no_grad():
            for i in tqdm(
                    range(0, len(entity_mentions_tensor),
                          len(entity_mentions_tensor) // split)):
                data = entity_mentions_tensor[
                    i:i + len(entity_mentions_tensor) // split, :]
                data_lengths = entity_mentions_lengths[
                    i:i + len(entity_mentions_tensor) // split]
                ementions_real_lstm, ementions_img_lstm = model.get_mention_embedding(
                    data, 0, data_lengths)
                ementions_real_lis.append(ementions_real_lstm.cpu())
                ementions_img_lis.append(ementions_img_lstm.cpu())
        del entity_mentions_tensor, ementions_real_lstm, ementions_img_lstm
        torch.cuda.empty_cache()
        ementions_real = torch.cat(ementions_real_lis).cuda()
        ementions_img = torch.cat(ementions_img_lis).cuda()
        ########################################################################
        if "olpbench" in args.data_dir:
            test_kb = kb(os.path.join(args.data_dir, "test_data.txt"),
                         em_map=em_map,
                         rm_map=rm_map)
        else:
            test_kb = kb(os.path.join(args.data_dir, "test.txt"),
                         em_map=em_map,
                         rm_map=rm_map)

        print("Loading all_known pickled data... (takes time since it is large)")
        all_known_e2, all_known_e1 = pickle.load(
            open(
                os.path.join(
                    args.data_dir,
                    "all_knowns_{}_linked.pkl".format(args.train_data_type)),
                "rb"))

        test_e1_tokens_tensor, test_e1_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 0], etoken_map, maxlen=args.max_seq_length)
        test_r_tokens_tensor, test_r_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 1], rtoken_map, maxlen=args.max_seq_length)
        test_e2_tokens_tensor, test_e2_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 2], etoken_map, maxlen=args.max_seq_length)

        # indices are used to fetch alternative answers while evaluating
        indices = torch.arange(len(test_kb.triples))
        test_data = TensorDataset(indices, test_e1_tokens_tensor,
                                  test_r_tokens_tensor, test_e2_tokens_tensor,
                                  test_e1_tokens_lengths, test_r_tokens_lengths,
                                  test_e2_tokens_lengths)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.eval_batch_size)

        split_dim_for_eval = 1
        if args.embedding_dim >= 256 and "olpbench" in args.data_dir and "rotat" in args.model:
            split_dim_for_eval = 4
        if args.embedding_dim >= 512 and "olpbench" in args.data_dir:
            split_dim_for_eval = 4
        if args.embedding_dim >= 512 and "olpbench" in args.data_dir and "rotat" in args.model:
            split_dim_for_eval = 6
        split_dim_for_eval = 1  # NOTE: this override disables the heuristics above

        for index, test_e1_tokens, test_r_tokens, test_e2_tokens, test_e1_lengths, test_r_lengths, test_e2_lengths in tqdm(
                test_dataloader, desc="Test dataloader"):
            # running (unnormalized) metrics so far
            print(metrics)
            test_e1_tokens, test_e1_lengths = test_e1_tokens.to(device), test_e1_lengths.to(device)
            test_r_tokens, test_r_lengths = test_r_tokens.to(device), test_r_lengths.to(device)
            test_e2_tokens, test_e2_lengths = test_e2_tokens.to(device), test_e2_lengths.to(device)
            with torch.no_grad():
                e1_real_lstm, e1_img_lstm = model.get_mention_embedding(
                    test_e1_tokens, 0, test_e1_lengths)
                r_real_lstm, r_img_lstm = model.get_mention_embedding(
                    test_r_tokens, 1, test_r_lengths)
                e2_real_lstm, e2_img_lstm = model.get_mention_embedding(
                    test_e2_tokens, 0, test_e2_lengths)

            for count in tqdm(range(index.shape[0]), desc="Evaluating"):
                this_e1_real = e1_real_lstm[count].unsqueeze(0)
                this_e1_img = e1_img_lstm[count].unsqueeze(0)
                this_r_real = r_real_lstm[count].unsqueeze(0)
                this_r_img = r_img_lstm[count].unsqueeze(0)
                this_e2_real = e2_real_lstm[count].unsqueeze(0)
                this_e2_img = e2_img_lstm[count].unsqueeze(0)

                simi_t = model.complex_score_e1_r_with_all_ementions(
                    this_e1_real, this_e1_img, this_r_real, this_r_img,
                    ementions_real, ementions_img,
                    split=split_dim_for_eval).squeeze(0)
                simi_h = model.complex_score_e2_r_with_all_ementions(
                    this_e2_real, this_e2_img, this_r_real, this_r_img,
                    ementions_real, ementions_img,
                    split=split_dim_for_eval).squeeze(0)

                # get known answers for filtered ranking
                ind = index[count]
                this_correct_mentions_e2 = test_kb.e2_all_answers[int(ind.item())]
                this_correct_mentions_e1 = test_kb.e1_all_answers[int(ind.item())]
                all_correct_mentions_e2 = all_known_e2.get(
                    (em_map[test_kb.triples[int(ind.item())][0]],
                     rm_map[test_kb.triples[int(ind.item())][1]]), [])
                all_correct_mentions_e1 = all_known_e1.get(
                    (em_map[test_kb.triples[int(ind.item())][2]],
                     rm_map[test_kb.triples[int(ind.item())][1]]), [])

                # tail metrics: read off the gold score, then mask every known answer
                best_score = simi_t[this_correct_mentions_e2].max()
                simi_t[all_correct_mentions_e2] = -20000000  # most negative value
                greatereq = simi_t.ge(best_score).float()
                equal = simi_t.eq(best_score).float()
                rank = greatereq.sum() + 1 + equal.sum() / 2.0
                metrics['mr_t'] += rank
                metrics['mrr_t'] += 1.0 / rank
                metrics['hits1_t'] += rank.le(1).float()
                metrics['hits10_t'] += rank.le(10).float()
                metrics['hits50_t'] += rank.le(50).float()

                # head metrics
                best_score = simi_h[this_correct_mentions_e1].max()
                simi_h[all_correct_mentions_e1] = -20000000  # most negative value
                greatereq = simi_h.ge(best_score).float()
                equal = simi_h.eq(best_score).float()
                rank = greatereq.sum() + 1 + equal.sum() / 2.0
                metrics['mr_h'] += rank
                metrics['mrr_h'] += 1.0 / rank
                metrics['hits1_h'] += rank.le(1).float()
                metrics['hits10_h'] += rank.le(10).float()
                metrics['hits50_h'] += rank.le(50).float()

                metrics['mr'] = (metrics['mr_h'] + metrics['mr_t']) / 2
                metrics['mrr'] = (metrics['mrr_h'] + metrics['mrr_t']) / 2
                metrics['hits1'] = (metrics['hits1_h'] + metrics['hits1_t']) / 2
                metrics['hits10'] = (metrics['hits10_h'] + metrics['hits10_t']) / 2
                metrics['hits50'] = (metrics['hits50_h'] + metrics['hits50_t']) / 2

        for key in metrics:
            metrics[key] = metrics[key] / len(test_kb.triples)
        print(metrics)
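# ----------------------------------------------------------------------------
# A minimal, self-contained sketch of the filtered, tie-aware ranking used in
# the eval loop above. The function name and toy data are illustrative, not
# part of the original code. The gold score is read off *before* all known
# answers are masked, so alternative correct answers cannot push the rank
# down. Note that .ge() also counts ties, matching the loop above, so a tied
# candidate contributes 1.5 to the rank (1 from greatereq, 0.5 from equal).
def filtered_rank(scores, gold_ids, known_ids):
    best_score = scores[gold_ids].max()
    scores = scores.clone()
    scores[known_ids] = -20000000  # mask every known answer, gold ones included
    greatereq = scores.ge(best_score).float().sum()
    equal = scores.eq(best_score).float().sum()
    return greatereq + 1 + equal / 2.0

# toy check: candidate 2 is gold and outscores all unmasked candidates
# filtered_rank(torch.tensor([0.9, 0.2, 0.9, 0.4, 0.1]), [2], [0, 2])  # tensor(1.)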
def main(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # read token maps
    etokens, etoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids", "entity_token_id_map.txt"))
    rtokens, rtoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids", "relation_token_id_map.txt"))
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "entity_id_map.txt"))
    relation_mentions, rm_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "relation_id_map.txt"))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    # build model (+1 for the unk token)
    if args.model == "complex":
        if args.separate_lstms:
            model = complexLSTM_2(
                len(etoken_map) + 1,
                len(rtoken_map) + 1,
                args.embedding_dim,
                initial_token_embedding=args.initial_token_embedding,
                entity_tokens=etokens,
                relation_tokens=rtokens,
                lstm_dropout=args.lstm_dropout)
        else:
            model = complexLSTM(
                len(etoken_map) + 1,
                len(rtoken_map) + 1,
                args.embedding_dim,
                initial_token_embedding=args.initial_token_embedding,
                entity_tokens=etokens,
                relation_tokens=rtokens,
                lstm_dropout=args.lstm_dropout)
    elif args.model == "rotate":
        model = rotatELSTM(
            len(etoken_map) + 1,
            len(rtoken_map) + 1,
            args.embedding_dim,
            initial_token_embedding=args.initial_token_embedding,
            entity_tokens=etokens,
            relation_tokens=rtokens,
            gamma=args.gamma_rotate,
            lstm_dropout=args.lstm_dropout)

    if args.do_eval:
        best_model = -1
        best_metrics = None
        if "olpbench" in args.data_dir:
            test_kb = kb(os.path.join(args.data_dir, "test_data.txt"),
                         em_map=em_map,
                         rm_map=rm_map)
        else:
            test_kb = kb(os.path.join(args.data_dir, "test.txt"),
                         em_map=em_map,
                         rm_map=rm_map)

        print("Loading all_known pickled data... (takes time since it is large)")
        all_known_e2, all_known_e1 = pickle.load(
            open(
                os.path.join(
                    args.data_dir,
                    "all_knowns_{}_linked.pkl".format(args.train_data_type)),
                "rb"))

        # evaluate every checkpoint in the directory and keep the best by hits@1
        models = os.listdir("models/author_data_2lstm_thorough")
        for model_path in tqdm(models):
            try:
                model_path = os.path.join("models/author_data_2lstm_thorough",
                                          model_path)
                metrics = {}
                for key in [
                        'mr', 'mrr', 'hits1', 'hits10', 'hits50',
                        'mr_t', 'mrr_t', 'hits1_t', 'hits10_t', 'hits50_t',
                        'mr_h', 'mrr_h', 'hits1_h', 'hits10_h', 'hits50_h'
                ]:
                    metrics[key] = 0

                checkpoint = torch.load(
                    model_path, map_location=lambda storage, loc: storage)
                model.load_state_dict(checkpoint['state_dict'])
                model.eval()

                # get embeddings for all entity mentions
                entity_mentions_tensor, entity_mentions_lengths = convert_string_to_indices(
                    entity_mentions,
                    etoken_map,
                    maxlen=args.max_seq_length,
                    use_tqdm=False)
                entity_mentions_tensor = entity_mentions_tensor.cuda()
                entity_mentions_lengths = entity_mentions_lengths.cuda()

                ementions_real_lis = []
                ementions_img_lis = []
                split = 100  # can't fit everything on the GPU at once, hence the split
                with torch.no_grad():
                    for i in range(0, len(entity_mentions_tensor),
                                   len(entity_mentions_tensor) // split):
                        data = entity_mentions_tensor[
                            i:i + len(entity_mentions_tensor) // split, :]
                        data_lengths = entity_mentions_lengths[
                            i:i + len(entity_mentions_tensor) // split]
                        ementions_real_lstm, ementions_img_lstm = model.get_mention_embedding(
                            data, 0, data_lengths)
                        ementions_real_lis.append(ementions_real_lstm.cpu())
                        ementions_img_lis.append(ementions_img_lstm.cpu())
                del entity_mentions_tensor, ementions_real_lstm, ementions_img_lstm
                torch.cuda.empty_cache()
                ementions_real = torch.cat(ementions_real_lis).cuda()
                ementions_img = torch.cat(ementions_img_lis).cuda()
                ################################################################
                test_e1_tokens_tensor, test_e1_tokens_lengths = convert_string_to_indices(
                    test_kb.triples[:, 0], etoken_map, maxlen=args.max_seq_length)
                test_r_tokens_tensor, test_r_tokens_lengths = convert_string_to_indices(
                    test_kb.triples[:, 1], rtoken_map, maxlen=args.max_seq_length)
                test_e2_tokens_tensor, test_e2_tokens_lengths = convert_string_to_indices(
                    test_kb.triples[:, 2], etoken_map, maxlen=args.max_seq_length)

                # indices are used to fetch alternative answers while evaluating
                indices = torch.arange(len(test_kb.triples))
                test_data = TensorDataset(indices, test_e1_tokens_tensor,
                                          test_r_tokens_tensor,
                                          test_e2_tokens_tensor,
                                          test_e1_tokens_lengths,
                                          test_r_tokens_lengths,
                                          test_e2_tokens_lengths)
                test_sampler = SequentialSampler(test_data)
                test_dataloader = DataLoader(test_data,
                                             sampler=test_sampler,
                                             batch_size=args.eval_batch_size)

                split_dim_for_eval = 1
                if args.embedding_dim >= 256 and "olpbench" in args.data_dir and "rotat" in args.model:
                    split_dim_for_eval = 4
                if args.embedding_dim >= 512 and "olpbench" in args.data_dir:
                    split_dim_for_eval = 4
                if args.embedding_dim >= 512 and "olpbench" in args.data_dir and "rotat" in args.model:
                    split_dim_for_eval = 6
                split_dim_for_eval = 1  # NOTE: this override disables the heuristics above

                for index, test_e1_tokens, test_r_tokens, test_e2_tokens, test_e1_lengths, test_r_lengths, test_e2_lengths in test_dataloader:
                    test_e1_tokens, test_e1_lengths = test_e1_tokens.to(device), test_e1_lengths.to(device)
                    test_r_tokens, test_r_lengths = test_r_tokens.to(device), test_r_lengths.to(device)
                    test_e2_tokens, test_e2_lengths = test_e2_tokens.to(device), test_e2_lengths.to(device)
                    with torch.no_grad():
                        e1_real_lstm, e1_img_lstm = model.get_mention_embedding(
                            test_e1_tokens, 0, test_e1_lengths)
                        r_real_lstm, r_img_lstm = model.get_mention_embedding(
                            test_r_tokens, 1, test_r_lengths)
                        e2_real_lstm, e2_img_lstm = model.get_mention_embedding(
                            test_e2_tokens, 0, test_e2_lengths)

                    for count in range(index.shape[0]):
                        this_e1_real = e1_real_lstm[count].unsqueeze(0)
                        this_e1_img = e1_img_lstm[count].unsqueeze(0)
                        this_r_real = r_real_lstm[count].unsqueeze(0)
                        this_r_img = r_img_lstm[count].unsqueeze(0)
                        this_e2_real = e2_real_lstm[count].unsqueeze(0)
                        this_e2_img = e2_img_lstm[count].unsqueeze(0)

                        simi_t = model.complex_score_e1_r_with_all_ementions(
                            this_e1_real, this_e1_img, this_r_real, this_r_img,
                            ementions_real, ementions_img,
                            split=split_dim_for_eval).squeeze(0)
                        simi_h = model.complex_score_e2_r_with_all_ementions(
                            this_e2_real, this_e2_img, this_r_real, this_r_img,
                            ementions_real, ementions_img,
                            split=split_dim_for_eval).squeeze(0)

                        # get known answers for filtered ranking
                        ind = index[count]
                        this_correct_mentions_e2 = test_kb.e2_all_answers[int(ind.item())]
                        this_correct_mentions_e1 = test_kb.e1_all_answers[int(ind.item())]
                        all_correct_mentions_e2 = all_known_e2.get(
                            (em_map[test_kb.triples[int(ind.item())][0]],
                             rm_map[test_kb.triples[int(ind.item())][1]]), [])
                        all_correct_mentions_e1 = all_known_e1.get(
                            (em_map[test_kb.triples[int(ind.item())][2]],
                             rm_map[test_kb.triples[int(ind.item())][1]]), [])

                        # tail metrics
                        best_score = simi_t[this_correct_mentions_e2].max()
                        simi_t[all_correct_mentions_e2] = -20000000  # most negative value
                        greatereq = simi_t.ge(best_score).float()
                        equal = simi_t.eq(best_score).float()
                        rank = greatereq.sum() + 1 + equal.sum() / 2.0
                        metrics['mr_t'] += rank
                        metrics['mrr_t'] += 1.0 / rank
                        metrics['hits1_t'] += rank.le(1).float()
                        metrics['hits10_t'] += rank.le(10).float()
                        metrics['hits50_t'] += rank.le(50).float()

                        # head metrics
                        best_score = simi_h[this_correct_mentions_e1].max()
                        simi_h[all_correct_mentions_e1] = -20000000  # most negative value
                        greatereq = simi_h.ge(best_score).float()
                        equal = simi_h.eq(best_score).float()
                        rank = greatereq.sum() + 1 + equal.sum() / 2.0
                        metrics['mr_h'] += rank
                        metrics['mrr_h'] += 1.0 / rank
                        metrics['hits1_h'] += rank.le(1).float()
                        metrics['hits10_h'] += rank.le(10).float()
                        metrics['hits50_h'] += rank.le(50).float()

                        metrics['mr'] = (metrics['mr_h'] + metrics['mr_t']) / 2
                        metrics['mrr'] = (metrics['mrr_h'] + metrics['mrr_t']) / 2
                        metrics['hits1'] = (metrics['hits1_h'] + metrics['hits1_t']) / 2
                        metrics['hits10'] = (metrics['hits10_h'] + metrics['hits10_t']) / 2
                        metrics['hits50'] = (metrics['hits50_h'] + metrics['hits50_t']) / 2

                for key in metrics:
                    metrics[key] = metrics[key] / len(test_kb.triples)
                if best_metrics is None or best_metrics['hits1'] < metrics['hits1']:
                    best_model = model_path
                    best_metrics = metrics
                print("best_hits1:", best_metrics['hits1'])
            except Exception:
                # skip checkpoints that fail to load or evaluate
                continue
        print(best_metrics)
        print(best_model)
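# ----------------------------------------------------------------------------
# Hedged sketch of the score the eval loops ask the model for: the standard
# ComplEx bilinear form Re(<e, r, conj(c)>) of one (entity, relation) pair
# against every candidate mention embedding. The project's real
# complex_score_e1_r_with_all_ementions may differ in details (e.g. the
# chunking controlled by its `split` argument); this only shows the shape and
# score logic.
def complex_score_all_sketch(e_re, e_im, r_re, r_im, cand_re, cand_im):
    # (batch, dim) @ (num_mentions, dim).T -> (batch, num_mentions)
    return ((e_re * r_re) @ cand_re.t()
            + (e_im * r_re) @ cand_im.t()
            + (e_re * r_im) @ cand_im.t()
            - (e_im * r_im) @ cand_re.t())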
etokens, etoken_map = utils.get_tokens_map(
    os.path.join(args.data_dir, "mapped_to_ids", "entity_token_id_map.txt"))
rtokens, rtoken_map = utils.get_tokens_map(
    os.path.join(args.data_dir, "mapped_to_ids", "relation_token_id_map.txt"))
entity_mentions, em_map = utils.read_mentions(
    os.path.join(args.data_dir, "mapped_to_ids", "entity_id_map.txt"))
relation_mentions, rm_map = utils.read_mentions(
    os.path.join(args.data_dir, "mapped_to_ids", "relation_id_map.txt"))

if args.model == "complex":
    model = complexLSTM(len(etoken_map) + 1,
                        len(rtoken_map) + 1,
                        args.embedding_dim,
                        initial_token_embedding=None,
                        entity_tokens=etokens,
                        relation_tokens=rtokens,
                        lstm_dropout=args.lstm_dropout)
elif args.model == "rotate":
    model = rotatELSTM(len(etoken_map) + 1,
                       len(rtoken_map) + 1,
                       args.embedding_dim,
                       initial_token_embedding=None,
                       entity_tokens=etokens,
                       relation_tokens=rtokens,
                       gamma=args.gamma_rotate,
                       lstm_dropout=args.lstm_dropout)

if args.resume:
    print("Resuming from:", args.resume)
    checkpoint = torch.load(args.resume)
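# ----------------------------------------------------------------------------
# The fragment above stops right after torch.load. For reference, a hedged
# sketch of the checkpoint round-trip the other scripts in this file use
# (assuming checkpoints are plain torch.save dicts with 'state_dict' and
# 'optimizer' keys, as the resume blocks elsewhere suggest):
def resume_sketch(model, optimizer, path):
    checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    if optimizer is not None and 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer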
def main(args):
    hits_1_triple = []
    hits_1_correct_answers = []
    hits_1_model_top10 = []
    hits_1_evidence = []
    # test-triple indices that the baseline already gets right for tail hits@1
    baseline_tail_hits1_indices = set([
        36, 91, 95, 101, 119, 158, 282, 397, 638, 728, 740, 763, 914, 959, 972,
        992, 1184, 1478, 1669, 1686, 1732, 1795, 1796, 1822, 1826, 1845, 1924,
        1939, 1943, 2055, 2178, 2317, 2319, 2325, 2482, 2513, 2589, 2627, 2674,
        2736, 2862, 2985, 3049, 3311, 3327, 3491, 3660, 3728, 3817, 3818, 4111,
        4263, 4387, 4437, 4438, 4452, 4525, 4591, 4670, 4856, 5114, 5159, 5318,
        5587, 5851, 5857, 5893, 5925, 5942, 5990, 6056, 6079, 6119, 6172, 6195,
        6211, 6228, 6262, 6267, 6460, 6491, 6509, 6584, 6676, 6699, 6862, 6982,
        7057, 7078, 7084, 7221, 7597, 7733, 7837, 8045, 8278, 8326, 8380, 8433,
        8453, 8479, 8534, 8540, 8742, 8813, 8860, 8906, 8930, 9234, 9333, 9500,
        9535, 9589, 9663, 9803, 9809, 9866, 9999
    ])
    baseline_correct = 0

    # args.n_times injected relations per test triple, in test order
    injected_rels = kb(args.evidence_file, em_map=None,
                       rm_map=None).triples[:, 1].reshape(-1, args.n_times)

    # read token maps
    etokens, etoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids", "entity_token_id_map.txt"))
    rtokens, rtoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids", "relation_token_id_map.txt"))
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "entity_id_map.txt"))
    _, rm_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "relation_id_map.txt"))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # build model (+1 for the unk token)
    model = complexLSTM(len(etoken_map) + 1,
                        len(rtoken_map) + 1,
                        args.embedding_dim,
                        initial_token_embedding=args.initial_token_embedding,
                        entity_tokens=etokens,
                        relation_tokens=rtokens,
                        lstm_dropout=0)
    if args.resume:
        print("Resuming from:", args.resume)
        checkpoint = torch.load(args.resume,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])

    model.eval()

    # get embeddings for all entity mentions
    entity_mentions_tensor, entity_mentions_lengths = convert_string_to_indices(
        entity_mentions,
        etoken_map,
        maxlen=args.max_seq_length,
        use_tqdm=True)
    entity_mentions_tensor = entity_mentions_tensor.cuda()
    entity_mentions_lengths = entity_mentions_lengths.cuda()

    ementions_real_lis = []
    ementions_img_lis = []
    split = 100  # can't fit everything on the GPU at once, hence the split
    with torch.no_grad():
        for i in tqdm(
                range(0, len(entity_mentions_tensor),
                      len(entity_mentions_tensor) // split)):
            data = entity_mentions_tensor[
                i:i + len(entity_mentions_tensor) // split, :]
            data_lengths = entity_mentions_lengths[
                i:i + len(entity_mentions_tensor) // split]
            ementions_real_lstm, ementions_img_lstm = model.get_mention_embedding(
                data, 0, data_lengths)
            ementions_real_lis.append(ementions_real_lstm.cpu())
            ementions_img_lis.append(ementions_img_lstm.cpu())
    del entity_mentions_tensor, ementions_real_lstm, ementions_img_lstm
    torch.cuda.empty_cache()
    ementions_real = torch.cat(ementions_real_lis).cuda()
    ementions_img = torch.cat(ementions_img_lis).cuda()
    ########################################################################
    if "olpbench" in args.data_dir:
        test_kb = kb(os.path.join(args.data_dir, "test_data.txt"),
                     em_map=em_map,
                     rm_map=rm_map)
    else:
        test_kb = kb(os.path.join(args.data_dir, "test.txt"),
                     em_map=em_map,
                     rm_map=rm_map)

    print("Loading all_known pickled data... (takes time since it is large)")
    all_known_e2, all_known_e1 = pickle.load(
        open(os.path.join(args.data_dir, "all_knowns_thorough_linked.pkl"),
             "rb"))

    test_e1_tokens_tensor, test_e1_tokens_lengths = convert_string_to_indices(
        test_kb.triples[:, 0], etoken_map, maxlen=args.max_seq_length)
    test_r_tokens_tensor, test_r_tokens_lengths = convert_string_to_indices(
        test_kb.triples[:, 1], rtoken_map, maxlen=args.max_seq_length)
    test_e2_tokens_tensor, test_e2_tokens_lengths = convert_string_to_indices(
        test_kb.triples[:, 2], etoken_map, maxlen=args.max_seq_length)

    # indices are used to fetch alternative answers while evaluating
    indices = torch.arange(len(test_kb.triples))
    test_data = TensorDataset(indices, test_e1_tokens_tensor,
                              test_r_tokens_tensor, test_e2_tokens_tensor,
                              test_e1_tokens_lengths, test_r_tokens_lengths,
                              test_e2_tokens_lengths)
    test_sampler = SequentialSampler(test_data)
    test_dataloader = DataLoader(test_data,
                                 sampler=test_sampler,
                                 batch_size=args.eval_batch_size)

    split_dim_for_eval = 1
    if args.embedding_dim >= 512 and "olpbench" in args.data_dir:
        split_dim_for_eval = 4

    for index, test_e1_tokens, test_r_tokens, test_e2_tokens, test_e1_lengths, test_r_lengths, test_e2_lengths in tqdm(
            test_dataloader, desc="Test dataloader"):
        test_e1_tokens, test_e1_lengths = test_e1_tokens.to(device), test_e1_lengths.to(device)
        test_r_tokens, test_r_lengths = test_r_tokens.to(device), test_r_lengths.to(device)
        test_e2_tokens, test_e2_lengths = test_e2_tokens.to(device), test_e2_lengths.to(device)
        with torch.no_grad():
            e1_real_lstm, e1_img_lstm = model.get_mention_embedding(
                test_e1_tokens, 0, test_e1_lengths)
            r_real_lstm, r_img_lstm = model.get_mention_embedding(
                test_r_tokens, 1, test_r_lengths)
            e2_real_lstm, e2_img_lstm = model.get_mention_embedding(
                test_e2_tokens, 0, test_e2_lengths)

        for count in tqdm(range(index.shape[0]), desc="Evaluating"):
            this_e1_real = e1_real_lstm[count].unsqueeze(0)
            this_e1_img = e1_img_lstm[count].unsqueeze(0)
            this_r_real = r_real_lstm[count].unsqueeze(0)
            this_r_img = r_img_lstm[count].unsqueeze(0)
            this_e2_real = e2_real_lstm[count].unsqueeze(0)
            this_e2_img = e2_img_lstm[count].unsqueeze(0)

            # get known answers for filtered ranking
            ind = index[count]
            this_correct_mentions_e2 = test_kb.e2_all_answers[int(ind.item())]
            this_correct_mentions_e1 = test_kb.e1_all_answers[int(ind.item())]
            all_correct_mentions_e2 = all_known_e2.get(
                (em_map[test_kb.triples[int(ind.item())][0]],
                 rm_map[test_kb.triples[int(ind.item())][1]]), [])
            all_correct_mentions_e1 = all_known_e1.get(
                (em_map[test_kb.triples[int(ind.item())][2]],
                 rm_map[test_kb.triples[int(ind.item())][1]]), [])

            if args.head_or_tail == "tail":
                simi = model.complex_score_e1_r_with_all_ementions(
                    this_e1_real, this_e1_img, this_r_real, this_r_img,
                    ementions_real, ementions_img,
                    split=split_dim_for_eval).squeeze(0)
                best_score = simi[this_correct_mentions_e2].max()
                simi[all_correct_mentions_e2] = -20000000  # most negative value
                greatereq = simi.ge(best_score).float()
                equal = simi.eq(best_score).float()
                rank = greatereq.sum() + 1 + equal.sum() / 2.0
            else:
                simi = model.complex_score_e2_r_with_all_ementions(
                    this_e2_real, this_e2_img, this_r_real, this_r_img,
                    ementions_real, ementions_img,
                    split=split_dim_for_eval).squeeze(0)
                best_score = simi[this_correct_mentions_e1].max()
                simi[all_correct_mentions_e1] = -20000000  # most negative value
                greatereq = simi.ge(best_score).float()
                equal = simi.eq(best_score).float()
                rank = greatereq.sum() + 1 + equal.sum() / 2.0

            # triples the baseline already got right are only counted, not sampled
            if int(ind.item()) in baseline_tail_hits1_indices:
                if rank <= 1:
                    baseline_correct += 1
                continue

            if rank <= 1:
                # hits@1: record triple, gold answers and the injected evidence
                hits_1_triple.append([
                    test_kb.triples[int(ind.item())][0],
                    test_kb.triples[int(ind.item())][1],
                    test_kb.triples[int(ind.item())][2]
                ])
                hits_1_evidence.append(injected_rels[int(ind.item())].tolist())
                if args.head_or_tail == "tail":
                    hits_1_correct_answers.append(
                        [entity_mentions[x] for x in this_correct_mentions_e2])
                else:
                    hits_1_correct_answers.append(
                        [entity_mentions[x] for x in this_correct_mentions_e1])
                hits_1_model_top10.append([])

    # print a random sample of the recorded hits@1 cases
    indices = list(range(len(hits_1_triple)))
    random.shuffle(indices)
    indices = indices[:args.sample]
    print(baseline_correct)
    for ind in indices:
        print(ind, "|", hits_1_triple[ind], "|", hits_1_correct_answers[ind],
              "|", hits_1_model_top10[ind], "|", hits_1_evidence[ind])
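# ----------------------------------------------------------------------------
# Layout assumed by the injected_rels reshape above: the evidence file holds
# args.n_times injected triples per test triple, in test order, so its
# relation column reshapes to (num_test_triples, n_times). Toy illustration
# (names invented):
def evidence_layout_demo():
    rels = np.array(["r_a1", "r_a2", "r_b1", "r_b2"])  # 2 evidence rels per triple
    return rels.reshape(-1, 2)  # row i -> the injected relations for test triple i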
def main(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # read token maps
    etokens, etoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids", "entity_token_id_map.txt"))
    rtokens, rtoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids", "relation_token_id_map.txt"))
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "entity_id_map.txt"))
    relation_mentions, rm_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "relation_id_map.txt"))

    # precompute token indices and lengths per mention:
    # [[max-length token indices for mention 0], [for mention 1], ...]
    # [length of mention 0, length of mention 1, ...]
    entity_token_indices, entity_lengths = utils.get_token_indices_from_mention_indices(
        entity_mentions, etoken_map, maxlen=args.max_seq_length, use_tqdm=True)
    relation_token_indices, relation_lengths = utils.get_token_indices_from_mention_indices(
        relation_mentions, rtoken_map, maxlen=args.max_seq_length, use_tqdm=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    # build model (+1 for the unk token)
    if args.model == "complex":
        if args.separate_lstms:
            model = complexLSTM_2(
                len(etoken_map) + 1,
                len(rtoken_map) + 1,
                args.embedding_dim,
                initial_token_embedding=args.initial_token_embedding,
                entity_tokens=etokens,
                relation_tokens=rtokens,
                lstm_dropout=args.lstm_dropout)
        else:
            model = complexLSTM(
                len(etoken_map) + 1,
                len(rtoken_map) + 1,
                args.embedding_dim,
                initial_token_embedding=args.initial_token_embedding,
                entity_tokens=etokens,
                relation_tokens=rtokens,
                lstm_dropout=args.lstm_dropout)
    elif args.model == "rotate":
        model = rotatELSTM(
            len(etoken_map) + 1,
            len(rtoken_map) + 1,
            args.embedding_dim,
            initial_token_embedding=args.initial_token_embedding,
            entity_tokens=etokens,
            relation_tokens=rtokens,
            gamma=args.gamma_rotate,
            lstm_dropout=args.lstm_dropout)

    if args.do_train:
        data_config = {
            'input_file': 'train_data_thorough.txt',
            'batch_size': args.train_batch_size,
            'use_batch_shared_entities': True,
            'min_size_batch_labels': args.train_batch_size,
            'max_size_prefix_label': 64,
            'device': 0
        }
        expt_settings = {
            'loss': 'bce',
            'replace_entities_by_tokens': True,
            'replace_relations_by_tokens': True,
            'max_lengths_tuple': [10, 10]
        }
        train_data = OneToNMentionRelationDataset(
            dataset_dir=os.path.join(args.data_dir, "mapped_to_ids"),
            is_training_data=True,
            **data_config,
            **expt_settings)
        train_data.create_data_tensors(
            dataset_dir=os.path.join(args.data_dir, "mapped_to_ids"),
            train_input_file='train_data_thorough.txt',
            valid_input_file='validation_data_linked.txt',
            test_input_file='test_data.txt',
        )
        train_loader = train_data.get_loader(
            shuffle=True,
            num_workers=8,
            drop_last=True,
        )
        optimizer = torch.optim.Adagrad(model.parameters(),
                                        lr=args.learning_rate,
                                        weight_decay=args.weight_decay)
        if args.resume:
            print("Resuming from:", args.resume)
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # load other things too if required

        model.train()
        BCEloss = torch.nn.BCEWithLogitsLoss(reduction='sum')
        for epoch in tqdm(range(0, args.num_train_epochs), desc="epoch"):
            iteration = 0
            for batch in tqdm(train_loader, desc="Train dataloader"):
                (inputs, normalizer_loss, normalizer_metric, labels, label_ids,
                 filter_mask, batch_shared_entities) = train_data.input_and_labels_to_device(
                     batch, training=True, device="cpu")
                labels = labels.cuda()
                all_outputs = []
                for mode, model_inputs in zip(["head", "tail"], inputs):
                    if model_inputs is None:
                        continue
                    # subtract two from the author's indices because our map is 2 less
                    if mode == "head":
                        batch_e2_indices = model_inputs[1] - 2
                        batch_r_indices = model_inputs[0] - 2
                        batch_e1_indices = batch_shared_entities - 2
                    else:
                        batch_e1_indices = model_inputs[0] - 2
                        batch_r_indices = model_inputs[1] - 2
                        batch_e2_indices = batch_shared_entities - 2

                    # look up precomputed token indices for each mention
                    # (faster than re-tokenizing mention strings every batch)
                    train_e1_mention_tensor, train_e1_lengths = convert_mention_to_token_indices(
                        batch_e1_indices.squeeze(1), entity_token_indices, entity_lengths)
                    train_r_mention_tensor, train_r_lengths = convert_mention_to_token_indices(
                        batch_r_indices.squeeze(1), relation_token_indices, relation_lengths)
                    train_e2_mention_tensor, train_e2_lengths = convert_mention_to_token_indices(
                        batch_e2_indices.squeeze(1), entity_token_indices, entity_lengths)

                    train_e1_mention_tensor, train_e1_lengths = train_e1_mention_tensor.cuda(), train_e1_lengths.cuda()
                    train_r_mention_tensor, train_r_lengths = train_r_mention_tensor.cuda(), train_r_lengths.cuda()
                    train_e2_mention_tensor, train_e2_lengths = train_e2_mention_tensor.cuda(), train_e2_lengths.cuda()

                    e1_real_lstm, e1_img_lstm = model.get_mention_embedding(
                        train_e1_mention_tensor, 0, train_e1_lengths)
                    r_real_lstm, r_img_lstm = model.get_mention_embedding(
                        train_r_mention_tensor, 1, train_r_lengths)
                    e2_real_lstm, e2_img_lstm = model.get_mention_embedding(
                        train_e2_mention_tensor, 0, train_e2_lengths)

                    if mode == "head":
                        output = model.complex_score_e2_r_with_all_ementions(
                            e2_real_lstm, e2_img_lstm, r_real_lstm, r_img_lstm,
                            e1_real_lstm, e1_img_lstm)
                    else:
                        output = model.complex_score_e1_r_with_all_ementions(
                            e1_real_lstm, e1_img_lstm, r_real_lstm, r_img_lstm,
                            e2_real_lstm, e2_img_lstm)
                    all_outputs.append(output)

                all_outputs = torch.cat(all_outputs)
                loss = BCEloss(all_outputs.view(-1), labels.view(-1))
                loss /= normalizer_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if iteration % args.print_loss_every == 0:
                    print("Current loss:", loss.item())
                iteration += 1

            if epoch % args.save_model_every == 0:
                utils.save_checkpoint(
                    {
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }, args.output_dir + "/checkpoint_epoch_{}".format(epoch + 1))

    if args.do_eval:
        # eval code
        metrics = {}
        for key in [
                'mr', 'mrr', 'hits1', 'hits10', 'hits50',
                'mr_t', 'mrr_t', 'hits1_t', 'hits10_t', 'hits50_t',
                'mr_h', 'mrr_h', 'hits1_h', 'hits10_h', 'hits50_h'
        ]:
            metrics[key] = 0

        if args.resume and not args.do_train:
            print("Resuming from:", args.resume)
            checkpoint = torch.load(args.resume,
                                    map_location=lambda storage, loc: storage)
            model.load_state_dict(checkpoint['state_dict'])

        model.eval()

        # get embeddings for all entity mentions
        entity_mentions_tensor, entity_mentions_lengths = convert_string_to_indices(
            entity_mentions,
            etoken_map,
            maxlen=args.max_seq_length,
            use_tqdm=True)
        entity_mentions_tensor = entity_mentions_tensor.cuda()
        entity_mentions_lengths = entity_mentions_lengths.cuda()

        ementions_real_lis = []
        ementions_img_lis = []
        split = 100  # can't fit everything on the GPU at once, hence the split
        with torch.no_grad():
            for i in tqdm(
                    range(0, len(entity_mentions_tensor),
                          len(entity_mentions_tensor) // split)):
                data = entity_mentions_tensor[
                    i:i + len(entity_mentions_tensor) // split, :]
                data_lengths = entity_mentions_lengths[
                    i:i + len(entity_mentions_tensor) // split]
                ementions_real_lstm, ementions_img_lstm = model.get_mention_embedding(
                    data, 0, data_lengths)
                ementions_real_lis.append(ementions_real_lstm.cpu())
                ementions_img_lis.append(ementions_img_lstm.cpu())
        del entity_mentions_tensor, ementions_real_lstm, ementions_img_lstm
        torch.cuda.empty_cache()
        ementions_real = torch.cat(ementions_real_lis).cuda()
        ementions_img = torch.cat(ementions_img_lis).cuda()
        ########################################################################
        if "olpbench" in args.data_dir:
            test_kb = kb(os.path.join(args.data_dir, "test_data.txt"),
                         em_map=em_map,
                         rm_map=rm_map)
        else:
            test_kb = kb(os.path.join(args.data_dir, "test.txt"),
                         em_map=em_map,
                         rm_map=rm_map)

        print("Loading all_known pickled data... (takes time since it is large)")
        all_known_e2, all_known_e1 = pickle.load(
            open(
                os.path.join(
                    args.data_dir,
                    "all_knowns_{}_linked.pkl".format(args.train_data_type)),
                "rb"))

        test_e1_tokens_tensor, test_e1_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 0], etoken_map, maxlen=args.max_seq_length)
        test_r_tokens_tensor, test_r_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 1], rtoken_map, maxlen=args.max_seq_length)
        test_e2_tokens_tensor, test_e2_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 2], etoken_map, maxlen=args.max_seq_length)

        # indices are used to fetch alternative answers while evaluating
        indices = torch.arange(len(test_kb.triples))
        test_data = TensorDataset(indices, test_e1_tokens_tensor,
                                  test_r_tokens_tensor, test_e2_tokens_tensor,
                                  test_e1_tokens_lengths, test_r_tokens_lengths,
                                  test_e2_tokens_lengths)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.eval_batch_size)

        split_dim_for_eval = 1
        if args.embedding_dim >= 256 and "olpbench" in args.data_dir and "rotat" in args.model:
            split_dim_for_eval = 4
        if args.embedding_dim >= 512 and "olpbench" in args.data_dir:
            split_dim_for_eval = 4
        if args.embedding_dim >= 512 and "olpbench" in args.data_dir and "rotat" in args.model:
            split_dim_for_eval = 6
        split_dim_for_eval = 1  # NOTE: this override disables the heuristics above

        for index, test_e1_tokens, test_r_tokens, test_e2_tokens, test_e1_lengths, test_r_lengths, test_e2_lengths in tqdm(
                test_dataloader, desc="Test dataloader"):
            # running (unnormalized) metrics so far
            print(metrics)
            test_e1_tokens, test_e1_lengths = test_e1_tokens.to(device), test_e1_lengths.to(device)
            test_r_tokens, test_r_lengths = test_r_tokens.to(device), test_r_lengths.to(device)
            test_e2_tokens, test_e2_lengths = test_e2_tokens.to(device), test_e2_lengths.to(device)
            with torch.no_grad():
                e1_real_lstm, e1_img_lstm = model.get_mention_embedding(
                    test_e1_tokens, 0, test_e1_lengths)
                r_real_lstm, r_img_lstm = model.get_mention_embedding(
                    test_r_tokens, 1, test_r_lengths)
                e2_real_lstm, e2_img_lstm = model.get_mention_embedding(
                    test_e2_tokens, 0, test_e2_lengths)

            for count in tqdm(range(index.shape[0]), desc="Evaluating"):
                this_e1_real = e1_real_lstm[count].unsqueeze(0)
                this_e1_img = e1_img_lstm[count].unsqueeze(0)
                this_r_real = r_real_lstm[count].unsqueeze(0)
                this_r_img = r_img_lstm[count].unsqueeze(0)
                this_e2_real = e2_real_lstm[count].unsqueeze(0)
                this_e2_img = e2_img_lstm[count].unsqueeze(0)

                simi_t = model.complex_score_e1_r_with_all_ementions(
                    this_e1_real, this_e1_img, this_r_real, this_r_img,
                    ementions_real, ementions_img,
                    split=split_dim_for_eval).squeeze(0)
                simi_h = model.complex_score_e2_r_with_all_ementions(
                    this_e2_real, this_e2_img, this_r_real, this_r_img,
                    ementions_real, ementions_img,
                    split=split_dim_for_eval).squeeze(0)

                # get known answers for filtered ranking
                ind = index[count]
                this_correct_mentions_e2 = test_kb.e2_all_answers[int(ind.item())]
                this_correct_mentions_e1 = test_kb.e1_all_answers[int(ind.item())]
                all_correct_mentions_e2 = all_known_e2.get(
                    (em_map[test_kb.triples[int(ind.item())][0]],
                     rm_map[test_kb.triples[int(ind.item())][1]]), [])
                all_correct_mentions_e1 = all_known_e1.get(
                    (em_map[test_kb.triples[int(ind.item())][2]],
                     rm_map[test_kb.triples[int(ind.item())][1]]), [])

                # tail metrics: read off the gold score, then mask every known answer
                best_score = simi_t[this_correct_mentions_e2].max()
                simi_t[all_correct_mentions_e2] = -20000000  # most negative value
                greatereq = simi_t.ge(best_score).float()
                equal = simi_t.eq(best_score).float()
                rank = greatereq.sum() + 1 + equal.sum() / 2.0
                metrics['mr_t'] += rank
                metrics['mrr_t'] += 1.0 / rank
                metrics['hits1_t'] += rank.le(1).float()
                metrics['hits10_t'] += rank.le(10).float()
                metrics['hits50_t'] += rank.le(50).float()

                # head metrics
                best_score = simi_h[this_correct_mentions_e1].max()
                simi_h[all_correct_mentions_e1] = -20000000  # most negative value
                greatereq = simi_h.ge(best_score).float()
                equal = simi_h.eq(best_score).float()
                rank = greatereq.sum() + 1 + equal.sum() / 2.0
                metrics['mr_h'] += rank
                metrics['mrr_h'] += 1.0 / rank
                metrics['hits1_h'] += rank.le(1).float()
                metrics['hits10_h'] += rank.le(10).float()
                metrics['hits50_h'] += rank.le(50).float()

                metrics['mr'] = (metrics['mr_h'] + metrics['mr_t']) / 2
                metrics['mrr'] = (metrics['mrr_h'] + metrics['mrr_t']) / 2
                metrics['hits1'] = (metrics['hits1_h'] + metrics['hits1_t']) / 2
                metrics['hits10'] = (metrics['hits10_h'] + metrics['hits10_t']) / 2
                metrics['hits50'] = (metrics['hits50_h'] + metrics['hits50_t']) / 2

        for key in metrics:
            metrics[key] = metrics[key] / len(test_kb.triples)
        print(metrics)
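# ----------------------------------------------------------------------------
# Shape sketch for the 1-vs-N objective in the training loop above (toy
# numbers, names illustrative): every (mention, relation) row is scored
# against the batch-shared candidate entities, and the summed BCE is divided
# by the dataset-provided normalizer rather than by a PyTorch mean reduction.
def one_vs_n_loss_demo():
    scores = torch.randn(3, 5)  # (rows, num_shared_candidate_entities)
    labels = torch.zeros(3, 5)
    labels[0, 2] = labels[1, 0] = labels[2, 4] = 1.0  # gold candidates
    bce = torch.nn.BCEWithLogitsLoss(reduction='sum')
    loss = bce(scores.view(-1), labels.view(-1))
    return loss / scores.numel()  # stand-in for normalizer_loss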