Example #1
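All of the examples below assume a common preamble of third-party and project-local imports. The exact module layout is not shown anywhere in these snippets, so the following is a best-guess sketch (`utils`, `kb`, and the model classes are project-local names):

import ast
import os
import pickle
import random
from collections import defaultdict

import multiprocessing as mp
from multiprocessing import Manager

import numpy as np
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from tqdm import tqdm, trange

import utils                # project-local helpers: read_mentions, get_tokens_map, ...
from kb import kb           # project-local triple store (module name assumed)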
def main():
	data_dir = "data/olpbench"
	entity_mentions,em_map = utils.read_mentions(os.path.join(data_dir,"mapped_to_ids","entity_id_map.txt"))
	relation_mentions,rm_map = utils.read_mentions(os.path.join(data_dir,"mapped_to_ids","relation_id_map.txt"))
	random.seed(42)
	np.random.seed(42)
	torch.manual_seed(42)
	test_kb = kb(os.path.join(data_dir,"test_data.txt"), em_map = em_map, rm_map = rm_map)
	xt_lines = open("helper_scripts/tmp/test_data_preds.txt.tail_thorough_f5_d300_e50.stage1",'r').readlines()
	cache_e = pickle.load(open("helper_scripts/tmp/top_1000_neighbors_val.pkl",'rb'))

	# drop test triple 4996 and its answer lists (presumably to stay aligned with the stage-1 prediction file)
	test_kb.triples = np.delete(test_kb.triples,4996,0)
	del test_kb.e1_all_answers[4996]
	del test_kb.e2_all_answers[4996]

	answers_t = []
	for line in tqdm(xt_lines, desc="nudging"):
		line = line.strip().split("\t")
		e1 = em_map[line[0]]
		e1_neighbors = cache_e.get(e1,[]) # [[12,12312],[211,2312],...]
		tmp_answers_t = ast.literal_eval(line[-1])
		for neighbor in e1_neighbors:
			for i in range(len(tmp_answers_t)):
				if neighbor[0]==tmp_answers_t[i][0]:
					tmp_answers_t[i][1] += neighbor[1]
		answers_t.append(tmp_answers_t)

	# note: the nudged tail answers are passed for both the tail and head slots here
	result = utils.get_metrics_using_topk(os.path.join(data_dir,"all_knowns_thorough_linked.pkl"),test_kb,answers_t,answers_t,em_map,rm_map)
	print(result)
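For concreteness, here is a toy run of the nudging merge above, assuming the formats implied by the comments: stage-1 answers are [mention_id, score] pairs and cached neighbors are [mention_id, count] pairs.

def nudge(answers, neighbors):
    # add each neighbor's co-occurrence count onto the matching stage-1 score
    for mention_id, count in neighbors:
        for pair in answers:
            if pair[0] == mention_id:
                pair[1] += count
    return answers

# [[12, 0.7], [211, 0.2]] nudged by [[12, 3]] -> [[12, 3.7], [211, 0.2]]
print(nudge([[12, 0.7], [211, 0.2]], [[12, 3]]))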
Example #2
def main():
	K = 1000
	data_dir = "data/olpbench"
	entity_mentions,em_map = utils.read_mentions(os.path.join(data_dir,"mapped_to_ids","entity_id_map.txt"))
	relation_mentions,rm_map = utils.read_mentions(os.path.join(data_dir,"mapped_to_ids","relation_id_map.txt"))
	random.seed(42)
	np.random.seed(42)
	torch.manual_seed(42)
	test_kb = kb(os.path.join(data_dir,"validation_data_linked.txt"), em_map = em_map, rm_map = rm_map)
	cache_e = pickle.load(open("helper_scripts/tmp/top_1000_neighbors_val.pkl",'rb'))
	if False:  # disabled one-time preprocessing: build the top-K neighbor cache
		graph = pickle.load(open("helper_scripts/tmp/graph_thorough_no_r_count_paths.pkl",'rb'))
		if False:  # disabled: build and dump the no-relation count graph
			train_kb = kb(os.path.join(data_dir,"train_data_thorough.txt"), em_map = None, rm_map = None)
			graph = {}

			f = open("helper_scripts/tmp/graph_thorough_no_r_count_paths.pkl",'wb')
			for triple in tqdm(train_kb.triples):
				e1 = em_map[triple[0].item()]
				r  = rm_map[triple[1].item()]
				e2 = em_map[triple[2].item()]

				if e1 not in graph:
					graph[e1] = defaultdict(int)
				graph[e1][e2] += 1

				if e2 not in graph:
					graph[e2] = defaultdict(int)
				graph[e2][e1] += 1

			pickle.dump(graph,f)
			f.close()
			exit()
		manager = Manager()
		cache_e = manager.dict()
		process_array = []
		for i in range(10):
			p = mp.Process(target = func, args = (cache_e,test_kb,i,1000,em_map,rm_map,graph,K))
			p.start()
			process_array.append(p)
		for p in process_array:
			p.join()
		cache_e_final = {}
		for key in cache_e:
			cache_e_final[key] = cache_e[key]
		f = open("helper_scripts/tmp/top_10000_neighbors_val.pkl",'wb')
		pickle.dump(cache_e_final,f)
		f.close()
		exit()
	answers_t = []
	answers_h = []
	for ind,triple in enumerate(test_kb.triples):
		e1 = em_map[triple[0].item()]
		r  = rm_map[triple[1].item()]
		e2 = em_map[triple[2].item()]
		answers_t.append(cache_e[e1])
		answers_h.append(cache_e[e2])
	
	metrics = utils.get_metrics_using_topk(os.path.join(data_dir,"all_knowns_thorough_linked.pkl"),test_kb,answers_t,answers_h,em_map,rm_map)
	print(metrics)
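The multiprocessing block above relies on a worker `func` that is not shown. A plausible sketch, assuming each worker fills `cache_e` with the K highest-count neighbors of every entity in its slice of the test triples (both the slicing scheme and the value format are inferred from how `cache_e` is consumed):

def func(cache_e, test_kb, worker_id, local_size, em_map, rm_map, graph, K):
    # each worker handles a contiguous slice of the test triples
    start = worker_id * local_size
    for triple in test_kb.triples[start:start + local_size]:
        for col in (0, 2):  # head and tail entity of the triple
            e = em_map[triple[col].item()]
            if e in cache_e or e not in graph:
                continue
            # rank neighbors by co-occurrence count and keep the top K
            ranked = sorted(graph[e].items(), key=lambda kv: -kv[1])[:K]
            cache_e[e] = [[neighbor, count] for neighbor, count in ranked]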
Example #3
def main():
    data_dir = "../olpbench"
    freq_r_tail = {}
    freq_r_head = {}
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "entity_id_map.txt"))
    _, rm_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "relation_id_map.txt"))

    # train_kb = kb(os.path.join(data_dir,"test_data.txt"), em_map = None, rm_map = None)
    train_kb = kb(os.path.join(data_dir, "train_data_thorough.txt"),
                  em_map=None,
                  rm_map=None)
    for triple in tqdm(train_kb.triples, desc="getting r freq"):
        e1 = triple[0].item()
        r = triple[1].item()
        e2 = triple[2].item()
        if r not in freq_r_tail:
            freq_r_tail[r] = {}
        if em_map[e2] not in freq_r_tail[r]:
            freq_r_tail[r][em_map[e2]] = 0
        freq_r_tail[r][em_map[e2]] += 1

        if r not in freq_r_head:
            freq_r_head[r] = {}
        if em_map[e1] not in freq_r_head[r]:
            freq_r_head[r][em_map[e1]] = 0
        freq_r_head[r][em_map[e1]] += 1

    f = open("../olpbench/r-freq_top100_thorough_head.pkl", "wb")
    final_data = {}
    for r in freq_r_head:
        final_list = list(
            zip(list(freq_r_head[r].values()), list(freq_r_head[r].keys())))
        final_list.sort(reverse=True)
        final_list = final_list[:100]
        final_data[r] = final_list
    pickle.dump(final_data, f)
    f.close()

    f = open("../olpbench/r-freq_top100_thorough_tail.pkl", "wb")
    final_data = {}
    for r in freq_r_tail:
        final_list = list(
            zip(list(freq_r_tail[r].values()), list(freq_r_tail[r].keys())))
        final_list.sort(reverse=True)
        final_list = final_list[:100]
        final_data[r] = final_list
    pickle.dump(final_data, f)
    f.close()
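A quick sanity check of the dumped pickles (paths as above): each file maps a relation mention id to its top-100 (count, entity_mention_id) pairs, highest count first. The output shown in the comment is illustrative only.

with open("../olpbench/r-freq_top100_thorough_tail.pkl", "rb") as f:
    tail_freq = pickle.load(f)
some_r = next(iter(tail_freq))
print(some_r, tail_freq[some_r][:3])  # e.g. 17 [(5120, 42), (3301, 7), (2988, 13)]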
Example #4
def main():
    data_dir = "data/olpbench"
    freq_r_tail = {}
    freq_r_head = {}
    entity_mentions,em_map = utils.read_mentions(os.path.join(data_dir,"mapped_to_ids","entity_id_map.txt"))
    _,rm_map = utils.read_mentions(os.path.join(data_dir,"mapped_to_ids","relation_id_map.txt"))

    train_kb = kb(os.path.join(data_dir,"train_data_thorough.txt"), em_map = None, rm_map = None)
    for triple in tqdm(train_kb.triples, desc="getting r freq"):
        e1 = triple[0].item()
        r  = triple[1].item()
        e2 = triple[2].item()
        if r not in freq_r_tail:
            freq_r_tail[r] = {}
        if em_map[e2] not in freq_r_tail[r]:
            freq_r_tail[r][em_map[e2]] = 0
        freq_r_tail[r][em_map[e2]] += 1

        if r not in freq_r_head:
            freq_r_head[r] = {}
        if em_map[e1] not in freq_r_head[r]:
            freq_r_head[r][em_map[e1]] = 0
        freq_r_head[r][em_map[e1]] += 1


    test_kb = kb(os.path.join(data_dir,"test_data.txt"), em_map = em_map, rm_map = rm_map)
    answers_t = []
    answers_h = []

    for triple in test_kb.triples:
        r = triple[1].item()
        val = freq_r_tail.get(r,{})
        this_answer = []
        for key in val:
            this_answer.append([key,val[key]])
        answers_t.append(this_answer)

        val = freq_r_head.get(r,{})
        this_answer = []
        for key in val:
            this_answer.append([key,val[key]])
        answers_h.append(this_answer)
    metrics = utils.get_metrics_using_topk(os.path.join(data_dir,"all_knowns_thorough_linked.pkl"),test_kb,answers_t,answers_h,em_map,rm_map)
    print(metrics)
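Note the answer format handed to `utils.get_metrics_using_topk` here: one list per test triple of [entity_mention_id, count] pairs, scored purely by how often the entity filled that slot for the relation in training. A toy illustration (ids and counts made up):

answers_t = [
    [[37, 12], [1024, 9]],  # triple 0: two tail candidates with their train counts
    [],                     # triple 1: its relation was never seen in training
]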
Example #5
	def load_embedding(self):
		# return torch.nn.Embedding(self.entity_count,512)

		# Step 1 get train data
		# DATA_DIR = "/home/mayank/olpbench"
		DATA_DIR = "/home/yatin/mayank/olpbench"
		etokens, etoken_map = utils.get_tokens_map(os.path.join(DATA_DIR,"mapped_to_ids","entity_token_id_map.txt"))
		rtokens, rtoken_map = utils.get_tokens_map(os.path.join(DATA_DIR,"mapped_to_ids","relation_token_id_map.txt"))
		entity_mentions,em_map = utils.read_mentions(os.path.join(DATA_DIR,"mapped_to_ids","entity_id_map.txt"))
		relation_mentions,rm_map = utils.read_mentions(os.path.join(DATA_DIR,"mapped_to_ids","relation_id_map.txt"))
		train_kb = kb(os.path.join(DATA_DIR,"train_data_thorough.txt"), em_map = em_map, rm_map = rm_map)

		# Step 2 get those 2 helper things for relation
		relation_token_indices, relation_lengths = utils.get_token_indices_from_mention_indices(relation_mentions, rtoken_map, maxlen=10, use_tqdm=True)

		# Step 3 for each entity get the top frequent relation
		freq_e = {}
		for triple in tqdm(train_kb.triples):
			# e1 = triple[0].item()
			# r  = triple[1].item()
			# e2 = triple[2].item()
			e1 = em_map[triple[0].item()]
			r  = rm_map[triple[1].item()]
			e2 = em_map[triple[2].item()]
			if e1 not in freq_e:
				freq_e[e1] = {}
			if r not in freq_e[e1]:
				freq_e[e1][r] = 0
			freq_e[e1][r] += 1
			if e2 not in freq_e:
				freq_e[e2] = {}
			if r not in freq_e[e2]:
				freq_e[e2][r] = 0
			freq_e[e2][r] += 1
		for key in tqdm(freq_e):
			freq_e[key] = max(freq_e[key], key = freq_e[key].get)


		# Step 4 get the embedding for that relation and save it in torch.nn.embedding against that entity
		from models import complexLSTM_2_all_e
		model = complexLSTM_2_all_e(196007,39303, 2473409, 512, lstm_dropout=0.1)  # presumably (#entity tokens, #relation tokens, #entity mentions, dim)
		print("Resuming...")
		# checkpoint = torch.load("/home/mayank/olpbench/models/author_data_2lstm_thorough_all-e/checkpoint_epoch_43",map_location="cpu")
		checkpoint = torch.load("/home/yatin/mayank/olpbench/models/checkpoint_epoch_43",map_location="cpu")
		model.load_state_dict(checkpoint['state_dict'])
		embedding = torch.nn.Embedding(self.entity_count,512)
		model.eval()
		model.to("cuda")
		for entity in trange(2473409, desc="Creating entity tensors finally!"):
			if entity not in freq_e:
				embedding.weight.data[entity] = torch.zeros(512)
			else:
				r_mention_tensor, r_lengths = convert_mention_to_token_indices([freq_e[entity]], relation_token_indices, relation_lengths)
				r_mention_tensor, r_lengths = r_mention_tensor.cuda(), r_lengths.cuda()
				r_real_lstm, r_img_lstm     = model.get_mention_embedding(r_mention_tensor,1,r_lengths)
				r_real_lstm = r_real_lstm[0]
				r_img_lstm  = r_img_lstm[0]
				embedding.weight.data[entity] = torch.cat([r_real_lstm, r_img_lstm]).cpu()
		return embedding
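`load_embedding` leans on a helper `convert_mention_to_token_indices` that is not shown. Given that Step 2 precomputes padded token rows (`relation_token_indices`) and their lengths (`relation_lengths`), a plausible sketch is:

def convert_mention_to_token_indices(mention_ids, token_indices, lengths):
    # look up the precomputed, padded token rows and lengths for the given
    # mention ids and return them as long tensors ready for the LSTM encoder
    rows = torch.tensor([token_indices[m] for m in mention_ids], dtype=torch.long)
    lens = torch.tensor([lengths[m] for m in mention_ids], dtype=torch.long)
    return rows, lens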
Example #6
def main():
    K = 2
    data_dir = "data/olpbench"
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "entity_id_map.txt"))
    relation_mentions, rm_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "relation_id_map.txt"))

    graph = pickle.load(open("helper_scripts/tmp/graph_thorough.pkl", 'rb'))
    if False:  # disabled one-time preprocessing: build the graph pickle
        train_kb = kb(os.path.join(data_dir, "train_data_thorough.txt"),
                      em_map=None,
                      rm_map=None)
        # train_kb = kb(os.path.join(data_dir,"test_data.txt"), em_map = None, rm_map = None)
        # train_kb = kb(os.path.join(data_dir,"validation_data_linked.txt"), em_map = None, rm_map = None)

        graph = {}
        graph_forward = {}
        graph_backward = {}
        rels_for_e1_e2 = {}

        for triple in tqdm(train_kb.triples):
            e1 = em_map[triple[0].item()]
            r = rm_map[triple[1].item()]
            e2 = em_map[triple[2].item()]

            if e1 not in graph:
                graph[e1] = []
            graph[e1].append([e2, r])

            if e2 not in graph:
                graph[e2] = []
            graph[e2].append([e1, len(relation_mentions) + r])

            # ----------------------------------------------------
            # if (e1,e2) not in rels_for_e1_e2:
            # 	rels_for_e1_e2[(e1,e2)] = []
            # rels_for_e1_e2[(e1,e2)].append(r)

            # if e1 not in graph_forward:
            # 	graph_forward[e1] = set()
            # # graph_forward[e1].add((e2,r))
            # graph_forward[e1].add(e2)

            # if e2 not in graph_backward:
            # 	graph_backward[e2] = set()
            # # graph_backward[e2].add((e1,r))
            # graph_backward[e2].add(e1)

    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    test_kb = kb(os.path.join(data_dir, "validation_data_linked.txt"),
                 em_map=None,
                 rm_map=None)
    count = 0

    two_hop_data = []
    for triple in tqdm(test_kb.triples, desc="test triples"):
        e1 = em_map[triple[0].item()]
        r = rm_map[triple[1].item()]
        e2 = em_map[triple[2].item()]

        if bfs(e1, e2, K, graph):
            count += 1

        # flag, proof = forw_back_check(e1,e2,graph_forward,graph_backward,rels_for_e1_e2)
        # if flag:
        # 	count += 1
        # 	two_hop_data.append(proof)

    print(count)
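`bfs` is not defined in this snippet. A sketch consistent with how it is called here (is e2 reachable from e1 within K hops of the graph built above, whose adjacency lists hold [neighbor, relation] pairs):

def bfs(source, target, max_hops, graph):
    # hop-bounded breadth-first search; only the neighbor id (element 0 of
    # each [neighbor, relation] pair) is followed
    frontier, seen = {source}, {source}
    for _ in range(max_hops):
        nxt = set()
        for node in frontier:
            for neighbor, _r in graph.get(node, []):
                if neighbor == target:
                    return True
                if neighbor not in seen:
                    seen.add(neighbor)
                    nxt.add(neighbor)
        frontier = nxt
    return False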
Example #7
def main(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)


    # read token maps
    etokens, etoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids",
                     "entity_token_id_map.txt"))
    rtokens, rtoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids",
                     "relation_token_id_map.txt"))
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "entity_id_map.txt"))
    relation_mentions, rm_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "relation_id_map.txt"))

    # create entity_token_indices and entity_lengths
    # [[max length indices for entity 0 ], [max length indices for entity 1], [max length indices for entity 2], ...]
    # [length of entity 0, length of entity 1, length of entity 2, ...]
    # entity_token_indices, entity_lengths = utils.get_token_indices_from_mention_indices(entity_mentions, etoken_map, maxlen=args.max_seq_length, use_tqdm=True)
    # relation_token_indices, relation_lengths = utils.get_token_indices_from_mention_indices(relation_mentions, rtoken_map, maxlen=args.max_seq_length, use_tqdm=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    # build the model (+1 for the unk token)
    if args.model == "complex":
        if args.separate_lstms:
            model = complexLSTM_2(
                len(etoken_map) + 1,
                len(rtoken_map) + 1,
                args.embedding_dim,
                initial_token_embedding=args.initial_token_embedding,
                entity_tokens=etokens,
                relation_tokens=rtokens,
                lstm_dropout=args.lstm_dropout)
        else:
            model = complexLSTM(
                len(etoken_map) + 1,
                len(rtoken_map) + 1,
                args.embedding_dim,
                initial_token_embedding=args.initial_token_embedding,
                entity_tokens=etokens,
                relation_tokens=rtokens,
                lstm_dropout=args.lstm_dropout)
    elif args.model == "rotate":
        model = rotatELSTM(
            len(etoken_map) + 1,
            len(rtoken_map) + 1,
            args.embedding_dim,
            initial_token_embedding=args.initial_token_embedding,
            entity_tokens=etokens,
            relation_tokens=rtokens,
            gamma=args.gamma_rotate,
            lstm_dropout=args.lstm_dropout)

    if args.do_eval:
        best_model = -1
        best_metrics = None
        if "olpbench" in args.data_dir:
            # test_kb = kb(os.path.join(args.data_dir,"test_data_sophis.txt"), em_map = em_map, rm_map = rm_map)
            test_kb = kb(os.path.join(args.data_dir, "test_data.txt"),
                         em_map=em_map,
                         rm_map=rm_map)
        else:
            test_kb = kb(os.path.join(args.data_dir, "test.txt"),
                         em_map=em_map,
                         rm_map=rm_map)

        print("Loading all_known pickled data...(takes times since large)")
        all_known_e2 = {}
        all_known_e1 = {}
        all_known_e2, all_known_e1 = pickle.load(
            open(
                os.path.join(
                    args.data_dir,
                    "all_knowns_{}_linked.pkl".format(args.train_data_type)),
                "rb"))
        models = os.listdir("models/author_data_2lstm_thorough")
        for model_path in tqdm(models):
            try:
                model_path = os.path.join("models/author_data_2lstm_thorough",
                                          model_path)
                #eval code
                metrics = {}
                metrics['mr'] = 0
                metrics['mrr'] = 0
                metrics['hits1'] = 0
                metrics['hits10'] = 0
                metrics['hits50'] = 0
                metrics['mr_t'] = 0
                metrics['mrr_t'] = 0
                metrics['hits1_t'] = 0
                metrics['hits10_t'] = 0
                metrics['hits50_t'] = 0
                metrics['mr_h'] = 0
                metrics['mrr_h'] = 0
                metrics['hits1_h'] = 0
                metrics['hits10_h'] = 0
                metrics['hits50_h'] = 0

                checkpoint = torch.load(
                    model_path, map_location=lambda storage, loc: storage)
                model.load_state_dict(checkpoint['state_dict'])

                model.eval()

                # get embeddings for all entity mentions
                entity_mentions_tensor, entity_mentions_lengths = convert_string_to_indices(
                    entity_mentions,
                    etoken_map,
                    maxlen=args.max_seq_length,
                    use_tqdm=False)
                entity_mentions_tensor = entity_mentions_tensor.cuda()
                entity_mentions_lengths = entity_mentions_lengths.cuda()

                ementions_real_lis = []
                ementions_img_lis = []
                split = 100  # can't fit all mentions on the GPU at once, hence the split
                with torch.no_grad():
                    for i in range(0, len(entity_mentions_tensor),
                                   len(entity_mentions_tensor) // split):
                        data = entity_mentions_tensor[
                            i:i + len(entity_mentions_tensor) // split, :]
                        data_lengths = entity_mentions_lengths[
                            i:i + len(entity_mentions_tensor) // split]
                        ementions_real_lstm, ementions_img_lstm = model.get_mention_embedding(
                            data, 0, data_lengths)
                        ementions_real_lis.append(ementions_real_lstm.cpu())
                        ementions_img_lis.append(ementions_img_lstm.cpu())
                del entity_mentions_tensor, ementions_real_lstm, ementions_img_lstm
                torch.cuda.empty_cache()
                ementions_real = torch.cat(ementions_real_lis).cuda()
                ementions_img = torch.cat(ementions_img_lis).cuda()
                ########################################################################

                test_e1_tokens_tensor, test_e1_tokens_lengths = convert_string_to_indices(
                    test_kb.triples[:, 0],
                    etoken_map,
                    maxlen=args.max_seq_length)
                test_r_tokens_tensor, test_r_tokens_lengths = convert_string_to_indices(
                    test_kb.triples[:, 1],
                    rtoken_map,
                    maxlen=args.max_seq_length)
                test_e2_tokens_tensor, test_e2_tokens_lengths = convert_string_to_indices(
                    test_kb.triples[:, 2],
                    etoken_map,
                    maxlen=args.max_seq_length)

                # e2_tensor = convert_string_to_indices(test_kb.triples[:,2], etoken_map)
                indices = torch.Tensor(
                    range(len(test_kb.triples))
                )  #indices would be used to fetch alternative answers while evaluating
                test_data = TensorDataset(indices, test_e1_tokens_tensor,
                                          test_r_tokens_tensor,
                                          test_e2_tokens_tensor,
                                          test_e1_tokens_lengths,
                                          test_r_tokens_lengths,
                                          test_e2_tokens_lengths)
                test_sampler = SequentialSampler(test_data)
                test_dataloader = DataLoader(test_data,
                                             sampler=test_sampler,
                                             batch_size=args.eval_batch_size)
                split_dim_for_eval = 1
                if (args.embedding_dim >= 256 and "olpbench" in args.data_dir
                        and "rotat" in args.model):
                    split_dim_for_eval = 4
                if (args.embedding_dim >= 512 and "olpbench" in args.data_dir):
                    split_dim_for_eval = 4
                if (args.embedding_dim >= 512 and "olpbench" in args.data_dir
                        and "rotat" in args.model):
                    split_dim_for_eval = 6
                split_dim_for_eval = 1  # note: this overrides the logic above
                for index, test_e1_tokens, test_r_tokens, test_e2_tokens, test_e1_lengths, test_r_lengths, test_e2_lengths in test_dataloader:
                    test_e1_tokens, test_e1_lengths = test_e1_tokens.to(
                        device), test_e1_lengths.to(device)
                    test_r_tokens, test_r_lengths = test_r_tokens.to(
                        device), test_r_lengths.to(device)
                    test_e2_tokens, test_e2_lengths = test_e2_tokens.to(
                        device), test_e2_lengths.to(device)
                    with torch.no_grad():
                        e1_real_lstm, e1_img_lstm = model.get_mention_embedding(
                            test_e1_tokens, 0, test_e1_lengths)
                        r_real_lstm, r_img_lstm = model.get_mention_embedding(
                            test_r_tokens, 1, test_r_lengths)
                        e2_real_lstm, e2_img_lstm = model.get_mention_embedding(
                            test_e2_tokens, 0, test_e2_lengths)

                    for count in range(index.shape[0]):
                        this_e1_real = e1_real_lstm[count].unsqueeze(0)
                        this_e1_img = e1_img_lstm[count].unsqueeze(0)
                        this_r_real = r_real_lstm[count].unsqueeze(0)
                        this_r_img = r_img_lstm[count].unsqueeze(0)
                        this_e2_real = e2_real_lstm[count].unsqueeze(0)
                        this_e2_img = e2_img_lstm[count].unsqueeze(0)
                        simi_t = model.complex_score_e1_r_with_all_ementions(
                            this_e1_real,
                            this_e1_img,
                            this_r_real,
                            this_r_img,
                            ementions_real,
                            ementions_img,
                            split=split_dim_for_eval).squeeze(0)
                        simi_h = model.complex_score_e2_r_with_all_ementions(
                            this_e2_real,
                            this_e2_img,
                            this_r_real,
                            this_r_img,
                            ementions_real,
                            ementions_img,
                            split=split_dim_for_eval).squeeze(0)
                        # get known answers for filtered ranking
                        ind = index[count]
                        this_correct_mentions_e2 = test_kb.e2_all_answers[int(
                            ind.item())]
                        this_correct_mentions_e1 = test_kb.e1_all_answers[int(
                            ind.item())]

                        all_correct_mentions_e2 = all_known_e2.get(
                            (em_map[test_kb.triples[int(ind.item())][0]],
                             rm_map[test_kb.triples[int(ind.item())][1]]), [])
                        all_correct_mentions_e1 = all_known_e1.get(
                            (em_map[test_kb.triples[int(ind.item())][2]],
                             rm_map[test_kb.triples[int(ind.item())][1]]), [])

                        # compute metrics
                        best_score = simi_t[this_correct_mentions_e2].max()
                        simi_t[
                            all_correct_mentions_e2] = -20000000  # MOST NEGATIVE VALUE
                        greatereq = simi_t.ge(best_score).float()
                        equal = simi_t.eq(best_score).float()
                        rank = greatereq.sum() + 1 + equal.sum() / 2.0

                        metrics['mr_t'] += rank
                        metrics['mrr_t'] += 1.0 / rank
                        metrics['hits1_t'] += rank.le(1).float()
                        metrics['hits10_t'] += rank.le(10).float()
                        metrics['hits50_t'] += rank.le(50).float()

                        best_score = simi_h[this_correct_mentions_e1].max()
                        simi_h[
                            all_correct_mentions_e1] = -20000000  # MOST NEGATIVE VALUE
                        greatereq = simi_h.ge(best_score).float()
                        equal = simi_h.eq(best_score).float()
                        rank = greatereq.sum() + 1 + equal.sum() / 2.0
                        metrics['mr_h'] += rank
                        metrics['mrr_h'] += 1.0 / rank
                        metrics['hits1_h'] += rank.le(1).float()
                        metrics['hits10_h'] += rank.le(10).float()
                        metrics['hits50_h'] += rank.le(50).float()

                        metrics['mr'] = (metrics['mr_h'] + metrics['mr_t']) / 2
                        metrics['mrr'] = (metrics['mrr_h'] +
                                          metrics['mrr_t']) / 2
                        metrics['hits1'] = (metrics['hits1_h'] +
                                            metrics['hits1_t']) / 2
                        metrics['hits10'] = (metrics['hits10_h'] +
                                             metrics['hits10_t']) / 2
                        metrics['hits50'] = (metrics['hits50_h'] +
                                             metrics['hits50_t']) / 2

                for key in metrics:
                    metrics[key] = metrics[key] / len(test_kb.triples)
                if best_metrics is None or best_metrics['hits1'] < metrics['hits1']:
                    best_model = model_path
                    best_metrics = metrics
                print("best_hits1:", best_metrics['hits1'])
            except Exception:  # skip checkpoints that fail to load or evaluate
                continue
        print(best_metrics)
        print(best_model)
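The per-triple ranking logic above is repeated verbatim for tails and heads; factored out, it is equivalent to the small helper below (same tensors, same masking constant, same tie-splitting rank).

def filtered_rank(scores, correct_ids, all_known_ids):
    # best gold score first, then mask every known answer, then a tie-aware
    # rank: (#scores >= best) + 1 + (#scores == best) / 2, exactly as inline
    best = scores[correct_ids].max()
    scores[all_known_ids] = -20000000  # MOST NEGATIVE VALUE
    return scores.ge(best).float().sum() + 1 + scores.eq(best).float().sum() / 2.0

# usage: rank = filtered_rank(simi_t, this_correct_mentions_e2, all_correct_mentions_e2)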
Example #8
    def _read(self, labels_fp):
        labels_list, _ = utils.read_mentions(labels_fp)
        for mcat_index, label in enumerate(labels_list):
            yield self.text_to_instance(label, mcat_index)
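This `_read` follows the AllenNLP `DatasetReader` contract (yield one `Instance` per label). The surrounding class is not shown; a minimal sketch of what it presumably looks like (`LabelReader` and the `text_to_instance` body are hypothetical):

from allennlp.data.dataset_readers import DatasetReader

class LabelReader(DatasetReader):  # hypothetical name
    def text_to_instance(self, label, mcat_index):
        # tokenize the label mention and attach mcat_index as metadata
        ...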
Example #9
def main(args):
	hits_1_triple = []
	hits_1_correct_answers = []
	hits_1_model_top10 = []

	nothits_50_triple = []
	nothits_50_correct_answers = []
	nothits_50_model_top10 = []


	# read token maps
	etokens, etoken_map = utils.get_tokens_map(os.path.join(args.data_dir,"mapped_to_ids","entity_token_id_map.txt"))
	rtokens, rtoken_map = utils.get_tokens_map(os.path.join(args.data_dir,"mapped_to_ids","relation_token_id_map.txt"))
	entity_mentions,em_map = utils.read_mentions(os.path.join(args.data_dir,"mapped_to_ids","entity_id_map.txt"))
	_,rm_map = utils.read_mentions(os.path.join(args.data_dir,"mapped_to_ids","relation_id_map.txt"))

	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
	random.seed(args.seed)
	np.random.seed(args.seed)
	torch.manual_seed(args.seed)

	# build the model (+1 for the unk token)
	model = complexLSTM(len(etoken_map)+1,len(rtoken_map)+1,args.embedding_dim, initial_token_embedding =args.initial_token_embedding, entity_tokens = etokens, relation_tokens = rtokens)

	if(args.resume):
		print("Resuming from:",args.resume)
		checkpoint = torch.load(args.resume,map_location=lambda storage, loc: storage)
		model.load_state_dict(checkpoint['state_dict'])

	model.eval()

	# get embeddings for all entity mentions
	entity_mentions_tensor, entity_mentions_lengths = convert_string_to_indices(entity_mentions,etoken_map,maxlen=args.max_seq_length,use_tqdm=True)
	entity_mentions_tensor = entity_mentions_tensor.cuda()
	entity_mentions_lengths = entity_mentions_lengths.cuda()

	ementions_real_lis = []
	ementions_img_lis = []
	split = 100  # can't fit all mentions on the GPU at once, hence the split
	with torch.no_grad():
		for i in tqdm(range(0,len(entity_mentions_tensor),len(entity_mentions_tensor)//split)):
			data = entity_mentions_tensor[i:i+len(entity_mentions_tensor)//split,:]
			data_lengths = entity_mentions_lengths[i:i+len(entity_mentions_tensor)//split]
			ementions_real_lstm,ementions_img_lstm = model.get_mention_embedding(data,0,data_lengths)			

			ementions_real_lis.append(ementions_real_lstm.cpu())
			ementions_img_lis.append(ementions_img_lstm.cpu())
	del entity_mentions_tensor,ementions_real_lstm,ementions_img_lstm
	torch.cuda.empty_cache()
	ementions_real = torch.cat(ementions_real_lis).cuda()
	ementions_img = torch.cat(ementions_img_lis).cuda()
	########################################################################

	if "olpbench" in args.data_dir:
		# test_kb = kb(os.path.join(args.data_dir,"test_data_sophis.txt"), em_map = em_map, rm_map = rm_map)
		test_kb = kb(os.path.join(args.data_dir,"test_data.txt"), em_map = em_map, rm_map = rm_map)
	else:
		test_kb = kb(os.path.join(args.data_dir,"test.txt"), em_map = em_map, rm_map = rm_map)
	print("Loading all_known pickled data...(takes times since large)")
	all_known_e2 = {}
	all_known_e1 = {}
	all_known_e2,all_known_e1 = pickle.load(open(os.path.join(args.data_dir,"all_knowns_simple_linked.pkl"),"rb"))


	test_e1_tokens_tensor, test_e1_tokens_lengths = convert_string_to_indices(test_kb.triples[:,0], etoken_map,maxlen=args.max_seq_length)
	test_r_tokens_tensor, test_r_tokens_lengths = convert_string_to_indices(test_kb.triples[:,1], rtoken_map,maxlen=args.max_seq_length)
	test_e2_tokens_tensor, test_e2_tokens_lengths = convert_string_to_indices(test_kb.triples[:,2], etoken_map,maxlen=args.max_seq_length)
	
	# e2_tensor = convert_string_to_indices(test_kb.triples[:,2], etoken_map)
	indices = torch.Tensor(range(len(test_kb.triples))) #indices would be used to fetch alternative answers while evaluating
	test_data = TensorDataset(indices, test_e1_tokens_tensor, test_r_tokens_tensor, test_e2_tokens_tensor, test_e1_tokens_lengths, test_r_tokens_lengths, test_e2_tokens_lengths)
	test_sampler = SequentialSampler(test_data)
	test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=args.eval_batch_size)
	split_dim_for_eval = 1
	if(args.embedding_dim>=512 and "olpbench" in args.data_dir):
		split_dim_for_eval = 4
	for index, test_e1_tokens, test_r_tokens, test_e2_tokens, test_e1_lengths, test_r_lengths, test_e2_lengths in tqdm(test_dataloader,desc="Test dataloader"):
		test_e1_tokens, test_e1_lengths = test_e1_tokens.to(device), test_e1_lengths.to(device)
		test_r_tokens, test_r_lengths = test_r_tokens.to(device), test_r_lengths.to(device)
		test_e2_tokens, test_e2_lengths = test_e2_tokens.to(device), test_e2_lengths.to(device)
		with torch.no_grad():
			e1_real_lstm, e1_img_lstm = model.get_mention_embedding(test_e1_tokens,0, test_e1_lengths)
			r_real_lstm, r_img_lstm = model.get_mention_embedding(test_r_tokens,1, test_r_lengths)	
			e2_real_lstm, e2_img_lstm = model.get_mention_embedding(test_e2_tokens,0, test_e2_lengths)


		for count in tqdm(range(index.shape[0]), desc="Evaluating"):
			this_e1_real = e1_real_lstm[count].unsqueeze(0)
			this_e1_img  = e1_img_lstm[count].unsqueeze(0)
			this_r_real  = r_real_lstm[count].unsqueeze(0)
			this_r_img   = r_img_lstm[count].unsqueeze(0)
			this_e2_real = e2_real_lstm[count].unsqueeze(0)
			this_e2_img  = e2_img_lstm[count].unsqueeze(0)
			
			# get known answers for filtered ranking
			ind = index[count]
			this_correct_mentions_e2 = test_kb.e2_all_answers[int(ind.item())]
			this_correct_mentions_e1 = test_kb.e1_all_answers[int(ind.item())] 

			all_correct_mentions_e2 = all_known_e2.get((em_map[test_kb.triples[int(ind.item())][0]],rm_map[test_kb.triples[int(ind.item())][1]]),[])
			all_correct_mentions_e1 = all_known_e1.get((em_map[test_kb.triples[int(ind.item())][2]],rm_map[test_kb.triples[int(ind.item())][1]]),[])
			if(args.head_or_tail=="tail"):
				simi = model.complex_score_e1_r_with_all_ementions(this_e1_real,this_e1_img,this_r_real,this_r_img,ementions_real,ementions_img,split=split_dim_for_eval).squeeze(0)
				best_score = simi[this_correct_mentions_e2].max()
				simi[all_correct_mentions_e2] = -20000000 # MOST NEGATIVE VALUE
				greatereq = simi.ge(best_score).float()
				equal = simi.eq(best_score).float()
				rank = greatereq.sum()+1+equal.sum()/2.0

			else:
				simi = model.complex_score_e2_r_with_all_ementions(this_e2_real,this_e2_img,this_r_real,this_r_img,ementions_real,ementions_img,split=split_dim_for_eval).squeeze(0)
				best_score = simi[this_correct_mentions_e1].max()
				simi[all_correct_mentions_e1] = -20000000 # MOST NEGATIVE VALUE
				greatereq = simi.ge(best_score).float()
				equal = simi.eq(best_score).float()
				rank = greatereq.sum()+1+equal.sum()/2.0

			if(rank<=1):
				#hits1
				hits_1_triple.append([test_kb.triples[int(ind.item())][0],test_kb.triples[int(ind.item())][1],test_kb.triples[int(ind.item())][2]])
				if(args.head_or_tail=="tail"):
					# hits_1_correct_answers.append(this_correct_mentions_e2)
					hits_1_correct_answers.append([entity_mentions[x] for x in this_correct_mentions_e2])
				else:
					hits_1_correct_answers.append([entity_mentions[x] for x in this_correct_mentions_e1])
				hits_1_model_top10.append([])
			elif(rank>50):
				#nothits50
				nothits_50_triple.append([test_kb.triples[int(ind.item())][0],test_kb.triples[int(ind.item())][1],test_kb.triples[int(ind.item())][2]])
				if(args.head_or_tail=="tail"):
					nothits_50_correct_answers.append([entity_mentions[x] for x in this_correct_mentions_e2])
				else:
					nothits_50_correct_answers.append([entity_mentions[x] for x in this_correct_mentions_e1])
				tmp = simi.sort()[1].tolist()[::-1][:10]  # indices of the 10 highest-scoring mentions
				nothits_50_model_top10.append([entity_mentions[x] for x in tmp])
	
	indices = list(range(len(hits_1_triple)))
	random.shuffle(indices)
	indices = indices[:args.sample]
	for ind in indices:
		print(ind,"|",hits_1_triple[ind],"|",hits_1_correct_answers[ind],"|",hits_1_model_top10[ind])
	print("---------------------------------------------------------------------------------------------")
	indices = list(range(len(nothits_50_triple)))
	random.shuffle(indices)
	indices = indices[:args.sample]
	for ind in indices:
		print(ind,"|",nothits_50_triple[ind],"|",nothits_50_correct_answers[ind],"|",nothits_50_model_top10[ind])
Example #10
def main():
    K = 1000000000000000000  # effectively unbounded hop limit
    data_dir = "data/olpbench"
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "entity_id_map.txt"))
    relation_mentions, rm_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "relation_id_map.txt"))
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    NPROCS = 50
    LOCAL_SIZE = 10000 // NPROCS
    # test_kb = kb(os.path.join(data_dir,"validation_data_linked.txt"), em_map = em_map, rm_map = rm_map)
    test_kb = kb(os.path.join(data_dir, "test_data.txt"),
                 em_map=em_map,
                 rm_map=rm_map)
    if True:  # always-on branch, kept as a switch like the disabled blocks above
        graph = pickle.load(
            open("helper_scripts/tmp/graph_thorough_no_r_count_paths.pkl",
                 'rb'))
        if False:  # disabled: build and dump the no-relation count graph
            train_kb = kb(os.path.join(data_dir, "train_data_thorough.txt"),
                          em_map=None,
                          rm_map=None)
            graph = {}

            f = open("helper_scripts/tmp/graph_thorough_no_r_count_paths.pkl",
                     'wb')
            for triple in tqdm(train_kb.triples):
                e1 = em_map[triple[0].item()]
                r = rm_map[triple[1].item()]
                e2 = em_map[triple[2].item()]

                if e1 not in graph:
                    graph[e1] = defaultdict(int)
                graph[e1][e2] += 1

                if e2 not in graph:
                    graph[e2] = defaultdict(int)
                graph[e2][e1] += 1

            pickle.dump(graph, f)
            f.close()
            exit()
        manager = Manager()
        cache_e = manager.dict()
        test_numerator = manager.dict()
        test_denominator = manager.dict()
        test_sub_lines = manager.dict()
        for i in range(NPROCS):
            test_sub_lines[i] = manager.list()
        process_array = []
        for i in range(NPROCS):
            os.system("mkdir helper_scripts/xt_tmp_folder/" + str(i))
            p = mp.Process(target=func,
                           args=(test_numerator, test_denominator,
                                 test_sub_lines, cache_e, test_kb, i,
                                 LOCAL_SIZE, em_map, rm_map, graph, K))
            p.start()
            process_array.append(p)
        for p in process_array:
            p.join()

        num = 0
        den = 0
        final_test_lines = []
        for key in test_numerator:
            num += test_numerator[key]
            den += test_denominator[key]
            final_test_lines.extend(test_sub_lines[key])
        print(num / den)
        print(num / 10000)  # 10000 = NPROCS * LOCAL_SIZE test triples
        f = open(
            "helper_scripts/tmp/test_data-full_neighbors-subset_2hop.txt.tail.xt",
            'w')
        for line in final_test_lines:
            f.write(line)
        f.close()
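As in the earlier neighbor-cache example, the worker `func` is not shown. Its contract can be inferred from how the results are merged: per worker, a hit count in `test_numerator`, a triple count in `test_denominator`, and one formatted output line per triple in `test_sub_lines`. A hypothetical skeleton:

def func(test_numerator, test_denominator, test_sub_lines, cache_e,
         test_kb, worker_id, local_size, em_map, rm_map, graph, K):
    # process triples [start, start + local_size), using the per-worker
    # scratch directory helper_scripts/xt_tmp_folder/<worker_id>
    start = worker_id * local_size
    hits, lines = 0, []
    for triple in test_kb.triples[start:start + local_size]:
        ...  # hop-bounded neighbor extraction over `graph`, capped by K
    test_numerator[worker_id] = hits
    test_denominator[worker_id] = local_size
    test_sub_lines[worker_id].extend(lines)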
Example #11
def main():
    K = 2
    data_dir = "data/olpbench"
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "entity_id_map.txt"))
    relation_mentions, rm_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "relation_id_map.txt"))

    graph = pickle.load(
        open("helper_scripts/tmp/graph_thorough_no_r.pkl", 'rb'))
    if False:  # disabled one-time preprocessing: build the no-relation graph
        train_kb = kb(os.path.join(data_dir, "train_data_thorough.txt"),
                      em_map=None,
                      rm_map=None)
        # train_kb = kb(os.path.join(data_dir,"test_data.txt"), em_map = None, rm_map = None)
        # train_kb = kb(os.path.join(data_dir,"validation_data_linked.txt"), em_map = None, rm_map = None)

        graph = {}
        graph_forward = {}
        graph_backward = {}
        rels_for_e1_e2 = {}

        f = open("helper_scripts/tmp/graph_thorough_no_r.pkl", 'wb')
        for triple in tqdm(train_kb.triples):
            e1 = em_map[triple[0].item()]
            r = rm_map[triple[1].item()]
            e2 = em_map[triple[2].item()]

            if e1 not in graph:
                graph[e1] = []
            graph[e1].append(e2)

            if e2 not in graph:
                graph[e2] = []
            graph[e2].append(e1)

        pickle.dump(graph, f)
        f.close()
        exit()

    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    test_kb = kb(os.path.join(data_dir, "test_data.txt"),
                 em_map=None,
                 rm_map=None)
    keshav_xt_lines = open("helper_scripts/tmp/test_data.txt.tail.xt",
                           'r').readlines()
    f = open("helper_scripts/tmp/test_data.txt.tail.xt.all-2-hop-neighbors",
             'w')

    for i in range(len(keshav_xt_lines)):
        keshav_xt_lines[i] = keshav_xt_lines[i][keshav_xt_lines[i].index(" ") +
                                                1:].strip()

    two_hop_data = []
    count = 0
    for ind, triple in tqdm(enumerate(test_kb.triples), desc="test triples"):
        e1 = em_map[triple[0].item()]
        r = rm_map[triple[1].item()]
        e2 = em_map[triple[2].item()]
        e1_neighbours = bfs(e1, K, graph)[:10000]
        if e2 in e1_neighbours:
            count += 1
        neighbour_string = ""
        for neighbor in e1_neighbours:
            neighbour_string += "__label__" + str(neighbor) + " "
        to_write = neighbour_string + keshav_xt_lines[ind]
        f.write(to_write + "\n")
        f.flush()
        # print(to_write,file = f)
    print(count)
    f.close()
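Here `bfs` differs from the reachability check in Example #6: it returns every entity within K hops of the source, in breadth-first order, over the no-relation graph whose adjacency lists are plain entity ids. A sketch consistent with that call:

def bfs(source, max_hops, graph):
    # collect all entities reachable within max_hops, in BFS order
    frontier, seen, order = [source], {source}, []
    for _ in range(max_hops):
        nxt = []
        for node in frontier:
            for neighbor in graph.get(node, []):
                if neighbor not in seen:
                    seen.add(neighbor)
                    nxt.append(neighbor)
                    order.append(neighbor)
        frontier = nxt
    return order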
Example #12
def main(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # read token maps
    etokens, etoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids",
                     "entity_token_id_map.txt"))
    rtokens, rtoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids",
                     "relation_token_id_map.txt"))
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "entity_id_map.txt"))
    relation_mentions, rm_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "relation_id_map.txt"))

    # create entity_token_indices and entity_lengths
    # [[max length indices for entity 0 ], [max length indices for entity 1], [max length indices for entity 2], ...]
    # [length of entity 0, length of entity 1, length of entity 2, ...]
    entity_token_indices, entity_lengths = utils.get_token_indices_from_mention_indices(
        entity_mentions, etoken_map, maxlen=args.max_seq_length, use_tqdm=True)
    relation_token_indices, relation_lengths = utils.get_token_indices_from_mention_indices(
        relation_mentions,
        rtoken_map,
        maxlen=args.max_seq_length,
        use_tqdm=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    # build the model (+1 for the unk token)
    model = complexLSTM_2_all_e_PLT(
        len(etoken_map) + 1,
        len(rtoken_map) + 1,
        len(entity_mentions),
        args.embedding_dim, {},
        initial_token_embedding=args.initial_token_embedding,
        entity_tokens=etokens,
        relation_tokens=rtokens,
        lstm_dropout=args.lstm_dropout)
    logsigmoid = torch.nn.LogSigmoid()
    if (args.do_train):
        data_config = {
            'input_file': 'train_data_thorough.txt',
            'batch_size': args.train_batch_size,
            'use_batch_shared_entities': True,
            'min_size_batch_labels': args.train_batch_size,
            'max_size_prefix_label': 64,
            'device': 0
        }
        expt_settings = {
            'loss': 'bce',
            'replace_entities_by_tokens': True,
            'replace_relations_by_tokens': True,
            'max_lengths_tuple': [10, 10]
        }
        train_data = OneToNMentionRelationDataset(dataset_dir=os.path.join(
            args.data_dir, "mapped_to_ids"),
                                                  is_training_data=True,
                                                  **data_config,
                                                  **expt_settings)
        train_data.create_data_tensors(
            dataset_dir=os.path.join(args.data_dir, "mapped_to_ids"),
            train_input_file='train_data_thorough.txt',
            valid_input_file='validation_data_linked.txt',
            test_input_file='test_data.txt',
        )
        train_loader = train_data.get_loader(
            shuffle=True,
            num_workers=8,
            drop_last=True,
        )
        # optimizer = torch.optim.Adagrad(model.parameters(),lr=args.learning_rate,weight_decay=args.weight_decay)
        optimizer = torch.optim.Adagrad(model.parameters(),
                                        lr=args.learning_rate)

        if (args.resume):
            print("Resuming from:", args.resume)
            # checkpoint = torch.load(args.resume,map_location=lambda storage, loc: storage)
            checkpoint = torch.load(args.resume, map_location="cpu")
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            del checkpoint
            torch.cuda.empty_cache()

            #Load other things too if required

        model.train()

        # crossEntropyLoss = torch.nn.CrossEntropyLoss(reduction='mean')
        BCEloss = torch.nn.BCEWithLogitsLoss(reduction='sum')
        for epoch in tqdm(range(0, args.num_train_epochs), desc="epoch"):
            iteration = 0
            for batch in tqdm(train_loader, desc="Train dataloader"):
                inputs, \
                normalizer_loss, \
                normalizer_metric, \
                labels, \
                label_ids, \
                filter_mask, \
                batch_shared_entities = train_data.input_and_labels_to_device(
                 batch,
                 training=True,
                 device="cpu"
                )
                if inputs[0] is None:
                    num_samples_for_head = 0
                else:
                    num_samples_for_head = inputs[0][0].shape[0]
                tree_nodes_head, head_mask, tree_nodes_tail, tail_mask,\
                tree_nodes_head_neg, head_mask_neg, tree_nodes_tail_neg, tail_mask_neg = model.get_tree_nodes(batch_shared_entities - 2, labels, num_samples_for_head)

                loss = torch.tensor(0, device=device)
                for mode, model_inputs in zip(["head", "tail"], inputs):
                    if model_inputs is None:
                        continue
                    # subtract two from author's indices because our map is 2 less
                    if mode == "head":
                        batch_e2_indices = model_inputs[1] - 2
                        batch_e2_indices = batch_e2_indices.where(
                            batch_e2_indices != -2,
                            torch.tensor(len(entity_mentions) - 2,
                                         dtype=torch.int32))
                        batch_r_indices = model_inputs[0] - 2
                        batch_r_indices = batch_r_indices.where(
                            batch_r_indices != -2,
                            torch.tensor(len(relation_mentions) - 2,
                                         dtype=torch.int32))
                        batch_e1_indices = batch_shared_entities - 2

                        train_r_mention_tensor, train_r_lengths = convert_mention_to_token_indices(
                            batch_r_indices.squeeze(1), relation_token_indices,
                            relation_lengths)
                        train_e2_mention_tensor, train_e2_lengths = convert_mention_to_token_indices(
                            batch_e2_indices.squeeze(1), entity_token_indices,
                            entity_lengths)

                        train_r_mention_tensor, train_r_lengths = train_r_mention_tensor.cuda(
                        ), train_r_lengths.cuda()
                        train_e2_mention_tensor, train_e2_lengths = train_e2_mention_tensor.cuda(
                        ), train_e2_lengths.cuda()

                        # e1_real_lstm, e1_img_lstm = model.get_atomic_entity_embeddings(batch_e1_indices.squeeze(1).long().cuda())
                        r_real_lstm, r_img_lstm = model.get_mention_embedding(
                            train_r_mention_tensor, 1, train_r_lengths)
                        e2_real_lstm, e2_img_lstm = model.get_mention_embedding(
                            train_e2_mention_tensor, 0, train_e2_lengths)

                        tmp_nodes_0 = torch.cat(
                            [tree_nodes_head[0], tree_nodes_head_neg[0]],
                            dim=1)
                        tmp_nodes_1 = torch.cat(
                            [tree_nodes_head[1], tree_nodes_head_neg[1]],
                            dim=1)
                        head_mask_neg *= -1
                        tmp_mask = torch.cat([head_mask, head_mask_neg], dim=1)

                        model_output = model.complex_score_e2_r_with_given_ementions(
                            e2_real_lstm, e2_img_lstm, r_real_lstm, r_img_lstm,
                            tmp_nodes_0, tmp_nodes_1)
                        loss = loss - (logsigmoid(
                            model_output * tmp_mask)).mean()
                        # neg
                        # model_output = model.complex_score_e2_r_with_given_ementions(e2_real_lstm,e2_img_lstm,r_real_lstm,r_img_lstm,tree_nodes_head_neg[0],tree_nodes_head_neg[1])
                        # loss = loss - (logsigmoid(-1*model_output*head_mask_neg)).mean()
                    else:
                        batch_e1_indices = model_inputs[0] - 2
                        batch_e1_indices = batch_e1_indices.where(
                            batch_e1_indices != -2,
                            torch.tensor(len(entity_mentions) - 2,
                                         dtype=torch.int32))
                        batch_r_indices = model_inputs[1] - 2
                        batch_r_indices = batch_r_indices.where(
                            batch_r_indices != -2,
                            torch.tensor(len(relation_mentions) - 2,
                                         dtype=torch.int32))
                        batch_e2_indices = batch_shared_entities - 2

                        train_e1_mention_tensor, train_e1_lengths = convert_mention_to_token_indices(
                            batch_e1_indices.squeeze(1), entity_token_indices,
                            entity_lengths)
                        train_r_mention_tensor, train_r_lengths = convert_mention_to_token_indices(
                            batch_r_indices.squeeze(1), relation_token_indices,
                            relation_lengths)

                        train_e1_mention_tensor, train_e1_lengths = train_e1_mention_tensor.cuda(
                        ), train_e1_lengths.cuda()
                        train_r_mention_tensor, train_r_lengths = train_r_mention_tensor.cuda(
                        ), train_r_lengths.cuda()

                        e1_real_lstm, e1_img_lstm = model.get_mention_embedding(
                            train_e1_mention_tensor, 0, train_e1_lengths)
                        r_real_lstm, r_img_lstm = model.get_mention_embedding(
                            train_r_mention_tensor, 1, train_r_lengths)

                        tmp_nodes_0 = torch.cat(
                            [tree_nodes_tail[0], tree_nodes_tail_neg[0]],
                            dim=1)
                        tmp_nodes_1 = torch.cat(
                            [tree_nodes_tail[1], tree_nodes_tail_neg[1]],
                            dim=1)
                        tail_mask_neg *= -1
                        tmp_mask = torch.cat([tail_mask, tail_mask_neg], dim=1)

                        model_output = model.complex_score_e1_r_with_given_ementions(
                            e1_real_lstm, e1_img_lstm, r_real_lstm, r_img_lstm,
                            tmp_nodes_0, tmp_nodes_1)
                        loss = loss - (logsigmoid(
                            model_output * tmp_mask)).mean()
                        #neg
                        # model_output = model.complex_score_e1_r_with_given_ementions(e1_real_lstm,e1_img_lstm,r_real_lstm,r_img_lstm,tree_nodes_tail_neg[0],tree_nodes_tail_neg[1])
                        # loss = loss - (logsigmoid(-1*model_output*tail_mask_neg)).mean()

                    # all_outputs.append(output)

                # all_outputs = torch.cat(all_outputs)
                # loss = BCEloss(all_outputs.view(-1),labels.view(-1))
                # loss /= normalizer_loss
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                if (iteration % args.print_loss_every == 0):
                    print("Current loss:", loss.item())
                iteration += 1
            if (epoch % args.save_model_every == 0):
                utils.save_checkpoint(
                    {
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    },
                    args.output_dir + "/checkpoint_epoch_{}".format(epoch + 1))

    if args.do_eval:
        #eval code
        metrics = {}
        metrics['mr'] = 0
        metrics['mrr'] = 0
        metrics['hits1'] = 0
        metrics['hits10'] = 0
        metrics['hits50'] = 0
        metrics['mr_t'] = 0
        metrics['mrr_t'] = 0
        metrics['hits1_t'] = 0
        metrics['hits10_t'] = 0
        metrics['hits50_t'] = 0
        metrics['mr_h'] = 0
        metrics['mrr_h'] = 0
        metrics['hits1_h'] = 0
        metrics['hits10_h'] = 0
        metrics['hits50_h'] = 0

        if (args.resume and not args.do_train):
            print("Resuming from:", args.resume)
            checkpoint = torch.load(args.resume,
                                    map_location=lambda storage, loc: storage)
            model.load_state_dict(checkpoint['state_dict'])

        model.eval()

        cut_nodes_indices = []
        cut_mask = []
        for i in model.hsoftmax.nodes_at_cut:
            cut_nodes_indices.append(model.hsoftmax.node_indices_for_t[i])
            cut_mask.append(model.hsoftmax.mask_for_t[i])
        cut_mask = torch.tensor(cut_mask, device=device)
        cut_nodes_indices = torch.tensor(cut_nodes_indices, device=device)
        cut_nodes_real, cut_nodes_img = model.E_atomic(
            cut_nodes_indices).chunk(2, dim=-1)

        # all_nodes_indices = torch.tensor(range(len(entity_mentions)), device=device).unsqueeze(0)
        # I checked that it can save at max 5 * len(entity_mentions)
        # ementions_real, ementions_img = model.E_atomic(tree_nodes_indices).chunk(2,dim=-1)
        # all_nodes_real, all_nodes_img = model.E_atomic(all_nodes_indices).chunk(2,dim=-1)

        ########################################################################
        if "olpbench" in args.data_dir:
            # test_kb = kb(os.path.join(args.data_dir,"train_data_thorough.txt"), em_map = em_map, rm_map = rm_map)
            test_kb = kb(os.path.join(args.data_dir, "test_data.txt"),
                         em_map=em_map,
                         rm_map=rm_map)
        else:
            test_kb = kb(os.path.join(args.data_dir, "test.txt"),
                         em_map=em_map,
                         rm_map=rm_map)

        print("Loading all_known pickled data...(takes times since large)")
        all_known_e2 = {}
        all_known_e1 = {}
        all_known_e2, all_known_e1 = pickle.load(
            open(
                os.path.join(
                    args.data_dir,
                    "all_knowns_{}_linked.pkl".format(args.train_data_type)),
                "rb"))

        test_e1_tokens_tensor, test_e1_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 0], etoken_map, maxlen=args.max_seq_length)
        test_r_tokens_tensor, test_r_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 1], rtoken_map, maxlen=args.max_seq_length)
        test_e2_tokens_tensor, test_e2_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 2], etoken_map, maxlen=args.max_seq_length)

        # e2_tensor = convert_string_to_indices(test_kb.triples[:,2], etoken_map)
        indices = torch.Tensor(
            range(len(test_kb.triples))
        )  #indices would be used to fetch alternative answers while evaluating
        test_data = TensorDataset(indices, test_e1_tokens_tensor,
                                  test_r_tokens_tensor, test_e2_tokens_tensor,
                                  test_e1_tokens_lengths,
                                  test_r_tokens_lengths,
                                  test_e2_tokens_lengths)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.eval_batch_size)
        split_dim_for_eval = 1
        # e1_support = []
        # e2_support = []

        for index, test_e1_tokens, test_r_tokens, test_e2_tokens, test_e1_lengths, test_r_lengths, test_e2_lengths in tqdm(
                test_dataloader, desc="Test dataloader"):
            print(metrics)
            test_e1_tokens, test_e1_lengths = test_e1_tokens.to(
                device), test_e1_lengths.to(device)
            test_r_tokens, test_r_lengths = test_r_tokens.to(
                device), test_r_lengths.to(device)
            test_e2_tokens, test_e2_lengths = test_e2_tokens.to(
                device), test_e2_lengths.to(device)
            with torch.no_grad():
                e1_real_lstm, e1_img_lstm = model.get_mention_embedding(
                    test_e1_tokens, 0, test_e1_lengths)
                r_real_lstm, r_img_lstm = model.get_mention_embedding(
                    test_r_tokens, 1, test_r_lengths)
                e2_real_lstm, e2_img_lstm = model.get_mention_embedding(
                    test_e2_tokens, 0, test_e2_lengths)

            for count in tqdm(range(index.shape[0]), desc="Evaluating"):
                # breakpoint()
                this_e1_real = e1_real_lstm[count]
                this_e1_img = e1_img_lstm[count]
                this_r_real = r_real_lstm[count]
                this_r_img = r_img_lstm[count]
                this_e2_real = e2_real_lstm[count]
                this_e2_img = e2_img_lstm[count]

                ind = index[count]
                this_correct_mentions_e2 = test_kb.e2_all_answers[int(
                    ind.item())]
                this_correct_mentions_e1 = test_kb.e1_all_answers[int(
                    ind.item())]

                all_correct_mentions_e2 = all_known_e2.get(
                    (em_map[test_kb.triples[int(ind.item())][0]],
                     rm_map[test_kb.triples[int(ind.item())][1]]), [])
                all_correct_mentions_e1 = all_known_e1.get(
                    (em_map[test_kb.triples[int(ind.item())][2]],
                     rm_map[test_kb.triples[int(ind.item())][1]]), [])

                with torch.no_grad():
                    simi_t = model.test_query(this_e1_real, this_e1_img,
                                              this_r_real, this_r_img, None,
                                              None, "tail", cut_mask,
                                              cut_nodes_real, cut_nodes_img)
                    simi_h = model.test_query(None, None, this_r_real,
                                              this_r_img, this_e2_real,
                                              this_e2_img, "head", cut_mask,
                                              cut_nodes_real, cut_nodes_img)
                    # simi_t = model.test_query_debug(this_e1_real, this_e1_img, this_r_real, this_r_img, None, None, "tail", cut_mask, cut_nodes_real, cut_nodes_img,this_correct_mentions_e2)
                    # simi_h = model.test_query_debug(None, None, this_r_real, this_r_img, this_e2_real, this_e2_real, "head", cut_mask, cut_nodes_real, cut_nodes_img,this_correct_mentions_e1)

                # get known answers for filtered ranking
                # e1_support.append(len(all_correct_mentions_e1))
                # e2_support.append(len(all_correct_mentions_e2))

                # compute metrics
                # for mention in all_correct_mentions_e2:
                # 	if mention in simi_t:
                # 		rank = torch.tensor(1.).cuda()
                # 		break
                # else:
                # 	rank = torch.tensor(2.).cuda()

                # rank = simi_t

                best_score = simi_t[this_correct_mentions_e2].max()
                simi_t[
                    all_correct_mentions_e2] = -20000000  # MOST NEGATIVE VALUE
                greatereq = simi_t.ge(best_score).float()
                equal = simi_t.eq(best_score).float()
                rank = greatereq.sum() + 1 + equal.sum() / 2.0

                metrics['mr_t'] += rank
                metrics['mrr_t'] += 1.0 / rank
                metrics['hits1_t'] += rank.le(1).float()
                metrics['hits10_t'] += rank.le(10).float()
                metrics['hits50_t'] += rank.le(50).float()

                # for mention in all_correct_mentions_e1:
                # 	if mention in simi_h:
                # 		rank = torch.tensor(1.).cuda()
                # 		break
                # else:
                # 	rank = torch.tensor(2.).cuda()

                # rank = simi_h

                best_score = simi_h[this_correct_mentions_e1].max()
                simi_h[
                    all_correct_mentions_e1] = -20000000  # MOST NEGATIVE VALUE
                greatereq = simi_h.ge(best_score).float()
                equal = simi_h.eq(best_score).float()
                rank = greatereq.sum() + 1 + equal.sum() / 2.0

                metrics['mr_h'] += rank
                metrics['mrr_h'] += 1.0 / rank
                metrics['hits1_h'] += rank.le(1).float()
                metrics['hits10_h'] += rank.le(10).float()
                metrics['hits50_h'] += rank.le(50).float()

                metrics['mr'] = (metrics['mr_h'] + metrics['mr_t']) / 2
                metrics['mrr'] = (metrics['mrr_h'] + metrics['mrr_t']) / 2
                metrics['hits1'] = (metrics['hits1_h'] +
                                    metrics['hits1_t']) / 2
                metrics['hits10'] = (metrics['hits10_h'] +
                                     metrics['hits10_t']) / 2
                metrics['hits50'] = (metrics['hits50_h'] +
                                     metrics['hits50_t']) / 2
        # e1_support = torch.tensor(e1_support)
        # e2_support = torch.tensor(e2_support)

        # import pdb
        # pdb.set_trace()
        for key in metrics:
            metrics[key] = metrics[key] / len(test_kb.triples)
        print(metrics)
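Every evaluation loop in these examples computes the same filtered, tie-aware rank: take the best score among the gold answers, mask all known answers with a large negative sentinel, then count how many candidates still score at least as well. A minimal sketch of that step with illustrative names (the real scripts mutate simi_t / simi_h in place rather than cloning):

import torch

def filtered_rank(scores, gold, all_known):
    # scores: (num_entities,) tensor of candidate scores for one query
    # gold: non-empty list of correct entity ids for this test triple
    # all_known: ids of every answer known from train/valid/test (superset of gold)
    best_score = scores[gold].max()
    scores = scores.clone()  # keep the caller's tensor intact in this sketch
    scores[all_known] = -2e7  # sentinel so known answers cannot outrank gold
    greatereq = scores.ge(best_score).float()
    equal = scores.eq(best_score).float()
    # candidates scoring >= best, plus half a rank per exact tie
    return greatereq.sum() + 1 + equal.sum() / 2.0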
def main(args):
    hits_1_triple = []
    hits_1_correct_answers = []
    hits_1_model_top10 = []
    hits_1_evidence = []
    baseline_tail_hits1_indices = set([
        36, 91, 95, 101, 119, 158, 282, 397, 638, 728, 740, 763, 914, 959, 972,
        992, 1184, 1478, 1669, 1686, 1732, 1795, 1796, 1822, 1826, 1845, 1924,
        1939, 1943, 2055, 2178, 2317, 2319, 2325, 2482, 2513, 2589, 2627, 2674,
        2736, 2862, 2985, 3049, 3311, 3327, 3491, 3660, 3728, 3817, 3818, 4111,
        4263, 4387, 4437, 4438, 4452, 4525, 4591, 4670, 4856, 5114, 5159, 5318,
        5587, 5851, 5857, 5893, 5925, 5942, 5990, 6056, 6079, 6119, 6172, 6195,
        6211, 6228, 6262, 6267, 6460, 6491, 6509, 6584, 6676, 6699, 6862, 6982,
        7057, 7078, 7084, 7221, 7597, 7733, 7837, 8045, 8278, 8326, 8380, 8433,
        8453, 8479, 8534, 8540, 8742, 8813, 8860, 8906, 8930, 9234, 9333, 9500,
        9535, 9589, 9663, 9803, 9809, 9866, 9999
    ])
    baseline_correct = 0
    # nothits_50_triple = []
    # nothits_50_correct_answers = []
    # nothits_50_model_top10 = []

    injected_rels = kb(args.evidence_file, em_map=None,
                       rm_map=None).triples[:, 1].reshape(-1, args.n_times)

    # read token maps
    etokens, etoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids",
                     "entity_token_id_map.txt"))
    rtokens, rtoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids",
                     "relation_token_id_map.txt"))
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "entity_id_map.txt"))
    _, rm_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "relation_id_map.txt"))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    #train code (+1 for unk token)
    model = complexLSTM(len(etoken_map) + 1,
                        len(rtoken_map) + 1,
                        args.embedding_dim,
                        initial_token_embedding=args.initial_token_embedding,
                        entity_tokens=etokens,
                        relation_tokens=rtokens,
                        lstm_dropout=0)

    if (args.resume):
        print("Resuming from:", args.resume)
        checkpoint = torch.load(args.resume,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])

    model.eval()

    # get embeddings for all entity mentions
    entity_mentions_tensor, entity_mentions_lengths = convert_string_to_indices(
        entity_mentions, etoken_map, maxlen=args.max_seq_length, use_tqdm=True)
    entity_mentions_tensor = entity_mentions_tensor.cuda()
    entity_mentions_lengths = entity_mentions_lengths.cuda()

    ementions_real_lis = []
    ementions_img_lis = []
    split = 100  # can't fit everything on the GPU at once, hence the split
    with torch.no_grad():
        for i in tqdm(
                range(0, len(entity_mentions_tensor),
                      len(entity_mentions_tensor) // split)):
            data = entity_mentions_tensor[i:i + len(entity_mentions_tensor) //
                                          split, :]
            data_lengths = entity_mentions_lengths[
                i:i + len(entity_mentions_tensor) // split]
            ementions_real_lstm, ementions_img_lstm = model.get_mention_embedding(
                data, 0, data_lengths)

            ementions_real_lis.append(ementions_real_lstm.cpu())
            ementions_img_lis.append(ementions_img_lstm.cpu())
    del entity_mentions_tensor, ementions_real_lstm, ementions_img_lstm
    torch.cuda.empty_cache()
    ementions_real = torch.cat(ementions_real_lis).cuda()
    ementions_img = torch.cat(ementions_img_lis).cuda()
    ########################################################################

    if "olpbench" in args.data_dir:
        test_kb = kb(os.path.join(args.data_dir, "test_data.txt"),
                     em_map=em_map,
                     rm_map=rm_map)
    else:
        test_kb = kb(os.path.join(args.data_dir, "test.txt"),
                     em_map=em_map,
                     rm_map=rm_map)
    print("Loading all_known pickled data...(takes times since large)")
    all_known_e2 = {}
    all_known_e1 = {}
    all_known_e2, all_known_e1 = pickle.load(
        open(os.path.join(args.data_dir, "all_knowns_thorough_linked.pkl"),
             "rb"))

    test_e1_tokens_tensor, test_e1_tokens_lengths = convert_string_to_indices(
        test_kb.triples[:, 0], etoken_map, maxlen=args.max_seq_length)
    test_r_tokens_tensor, test_r_tokens_lengths = convert_string_to_indices(
        test_kb.triples[:, 1], rtoken_map, maxlen=args.max_seq_length)
    test_e2_tokens_tensor, test_e2_tokens_lengths = convert_string_to_indices(
        test_kb.triples[:, 2], etoken_map, maxlen=args.max_seq_length)

    # e2_tensor = convert_string_to_indices(test_kb.triples[:,2], etoken_map)
    indices = torch.Tensor(
        range(len(test_kb.triples))
    )  #indices would be used to fetch alternative answers while evaluating
    test_data = TensorDataset(indices, test_e1_tokens_tensor,
                              test_r_tokens_tensor, test_e2_tokens_tensor,
                              test_e1_tokens_lengths, test_r_tokens_lengths,
                              test_e2_tokens_lengths)
    test_sampler = SequentialSampler(test_data)
    test_dataloader = DataLoader(test_data,
                                 sampler=test_sampler,
                                 batch_size=args.eval_batch_size)
    split_dim_for_eval = 1
    if (args.embedding_dim >= 512 and "olpbench" in args.data_dir):
        split_dim_for_eval = 4
    for index, test_e1_tokens, test_r_tokens, test_e2_tokens, test_e1_lengths, test_r_lengths, test_e2_lengths in tqdm(
            test_dataloader, desc="Test dataloader"):
        test_e1_tokens, test_e1_lengths = test_e1_tokens.to(
            device), test_e1_lengths.to(device)
        test_r_tokens, test_r_lengths = test_r_tokens.to(
            device), test_r_lengths.to(device)
        test_e2_tokens, test_e2_lengths = test_e2_tokens.to(
            device), test_e2_lengths.to(device)
        with torch.no_grad():
            e1_real_lstm, e1_img_lstm = model.get_mention_embedding(
                test_e1_tokens, 0, test_e1_lengths)
            r_real_lstm, r_img_lstm = model.get_mention_embedding(
                test_r_tokens, 1, test_r_lengths)
            e2_real_lstm, e2_img_lstm = model.get_mention_embedding(
                test_e2_tokens, 0, test_e2_lengths)

        for count in tqdm(range(index.shape[0]), desc="Evaluating"):
            this_e1_real = e1_real_lstm[count].unsqueeze(0)
            this_e1_img = e1_img_lstm[count].unsqueeze(0)
            this_r_real = r_real_lstm[count].unsqueeze(0)
            this_r_img = r_img_lstm[count].unsqueeze(0)
            this_e2_real = e2_real_lstm[count].unsqueeze(0)
            this_e2_img = e2_img_lstm[count].unsqueeze(0)

            # get known answers for filtered ranking
            ind = index[count]
            this_correct_mentions_e2 = test_kb.e2_all_answers[int(ind.item())]
            this_correct_mentions_e1 = test_kb.e1_all_answers[int(ind.item())]

            all_correct_mentions_e2 = all_known_e2.get(
                (em_map[test_kb.triples[int(ind.item())][0]],
                 rm_map[test_kb.triples[int(ind.item())][1]]), [])
            all_correct_mentions_e1 = all_known_e1.get(
                (em_map[test_kb.triples[int(ind.item())][2]],
                 rm_map[test_kb.triples[int(ind.item())][1]]), [])
            if (args.head_or_tail == "tail"):
                simi = model.complex_score_e1_r_with_all_ementions(
                    this_e1_real,
                    this_e1_img,
                    this_r_real,
                    this_r_img,
                    ementions_real,
                    ementions_img,
                    split=split_dim_for_eval).squeeze(0)
                best_score = simi[this_correct_mentions_e2].max()
                simi[
                    all_correct_mentions_e2] = -20000000  # MOST NEGATIVE VALUE
                greatereq = simi.ge(best_score).float()
                equal = simi.eq(best_score).float()
                rank = greatereq.sum() + 1 + equal.sum() / 2.0

            else:
                simi = model.complex_score_e2_r_with_all_ementions(
                    this_e2_real,
                    this_e2_img,
                    this_r_real,
                    this_r_img,
                    ementions_real,
                    ementions_img,
                    split=split_dim_for_eval).squeeze(0)
                best_score = simi[this_correct_mentions_e1].max()
                simi[
                    all_correct_mentions_e1] = -20000000  # MOST NEGATIVE VALUE
                greatereq = simi.ge(best_score).float()
                equal = simi.eq(best_score).float()
                rank = greatereq.sum() + 1 + equal.sum() / 2.0

            if int(ind.item()) in baseline_tail_hits1_indices:
                if rank <= 1:
                    baseline_correct += 1
                continue
            if (rank <= 1):
                #hits1
                hits_1_triple.append([
                    test_kb.triples[int(ind.item())][0],
                    test_kb.triples[int(ind.item())][1],
                    test_kb.triples[int(ind.item())][2]
                ])
                hits_1_evidence.append(injected_rels[int(ind.item())].tolist())
                if (args.head_or_tail == "tail"):
                    # hits_1_correct_answers.append(this_correct_mentions_e2)
                    hits_1_correct_answers.append(
                        [entity_mentions[x] for x in this_correct_mentions_e2])
                else:
                    hits_1_correct_answers.append(
                        [entity_mentions[x] for x in this_correct_mentions_e1])
                hits_1_model_top10.append([])
            # elif(rank>50):
            # 	#nothits50
            # 	nothits_50_triple.append([test_kb.triples[int(ind.item())][0],test_kb.triples[int(ind.item())][1],test_kb.triples[int(ind.item())][2]])
            # 	if(args.head_or_tail=="tail"):
            # 		nothits_50_correct_answers.append([entity_mentions[x] for x in this_correct_mentions_e2])
            # 	else:
            # 		nothits_50_correct_answers.append([entity_mentions[x] for x in this_correct_mentions_e1])
            # 	tmp = simi.sort()[1].tolist()[::-1][:10]
            # 	nothits_50_model_top10.append([entity_mentions[x] for x in tmp])

    indices = list(range(len(hits_1_triple)))
    random.shuffle(indices)
    indices = indices[:args.sample]
    print(baseline_correct)
    for ind in indices:
        print(ind, "|", hits_1_triple[ind], "|", hits_1_correct_answers[ind],
              "|", hits_1_model_top10[ind], "|", hits_1_evidence[ind])
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--eval_batch_size",
                        default=512,
                        type=int,
                        help="Total batch size for eval.")
    args = parser.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # read token maps
    etokens, etoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids",
                     "entity_token_id_map.txt"))
    rtokens, rtoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids",
                     "relation_token_id_map.txt"))
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "entity_id_map.txt"))
    relation_mentions, rm_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "relation_id_map.txt"))

    if args.model == "complex":
        model = complexLSTM(len(etoken_map) + 1,
                            len(rtoken_map) + 1,
                            args.embedding_dim,
                            initial_token_embedding=None,
                            entity_tokens=etokens,
                            relation_tokens=rtokens,
                            lstm_dropout=args.lstm_dropout)
    elif args.model == "rotate":
        model = rotatELSTM(len(etoken_map) + 1,
                           len(rtoken_map) + 1,
                           args.embedding_dim,
                           initial_token_embedding=None,
                           entity_tokens=etokens,
                           relation_tokens=rtokens,
                           gamma=args.gamma_rotate,
                           lstm_dropout=args.lstm_dropout)
def main(args):
    # read token maps
    etokens, etoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids",
                     "entity_token_id_map.txt"))
    rtokens, rtoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids",
                     "relation_token_id_map.txt"))
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "entity_id_map.txt"))
    _, rm_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "relation_id_map.txt"))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    #train code (+1 for unk token)
    if args.model == "complex":
        if args.separate_lstms:
            model = complexLSTM_2(
                len(etoken_map) + 1,
                len(rtoken_map) + 1,
                args.embedding_dim,
                initial_token_embedding=args.initial_token_embedding,
                entity_tokens=etokens,
                relation_tokens=rtokens,
                lstm_dropout=args.lstm_dropout)
        else:
            model = complexLSTM(
                len(etoken_map) + 1,
                len(rtoken_map) + 1,
                args.embedding_dim,
                initial_token_embedding=args.initial_token_embedding,
                entity_tokens=etokens,
                relation_tokens=rtokens,
                lstm_dropout=args.lstm_dropout)
    elif args.model == "rotate":
        model = rotatELSTM(
            len(etoken_map) + 1,
            len(rtoken_map) + 1,
            args.embedding_dim,
            initial_token_embedding=args.initial_token_embedding,
            entity_tokens=etokens,
            relation_tokens=rtokens,
            gamma=args.gamma_rotate,
            lstm_dropout=args.lstm_dropout)
    if (args.do_train):
        optimizer = torch.optim.Adagrad(model.parameters(),
                                        lr=args.learning_rate,
                                        weight_decay=args.weight_decay)

        if (args.resume):
            print("Resuming from:", args.resume)
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            #Load other things too if required

        model.train()
        if "olpbench" in args.data_dir:
            train_kb = kb(os.path.join(
                args.data_dir,
                "train_data_{}.txt".format(args.train_data_type)),
                          em_map=em_map,
                          rm_map=rm_map)
            # train_kb = kb(os.path.join(args.data_dir,"train_data_thorough_r_sorted.txt"), em_map = em_map, rm_map = rm_map)
            # train_kb = kb(os.path.join(args.data_dir,"test_data.txt"), em_map = em_map, rm_map = rm_map)

        else:
            train_kb = kb(os.path.join(args.data_dir, "train.txt"),
                          em_map=em_map,
                          rm_map=rm_map)

        train_data = Dataset(train_kb.triples)
        train_sampler = RandomSampler(train_data, replacement=False)
        #train_sampler = SequentialSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        # crossEntropyLoss = torch.nn.CrossEntropyLoss(reduction='mean')
        BCEloss = torch.nn.BCEWithLogitsLoss(reduction='sum')

        for epoch in tqdm(range(0, args.num_train_epochs), desc="epoch"):
            iteration = 0
            for train_e1_batch, train_r_batch, train_e2_batch in tqdm(
                    train_dataloader, desc="Train dataloader"):
                # skip this batch
                if (random.random() < args.skip_train_prob):
                    continue
                batch_size = len(train_e1_batch)
                train_e1_mention_tensor, train_e1_lengths = convert_string_to_indices(
                    train_e1_batch, etoken_map, maxlen=args.max_seq_length)
                train_r_mention_tensor, train_r_lengths = convert_string_to_indices(
                    train_r_batch, rtoken_map, maxlen=args.max_seq_length)
                train_e2_mention_tensor, train_e2_lengths = convert_string_to_indices(
                    train_e2_batch, etoken_map, maxlen=args.max_seq_length)

                train_e1_mention_tensor, train_e1_lengths = train_e1_mention_tensor.cuda(
                ), train_e1_lengths.cuda()
                train_r_mention_tensor, train_r_lengths = train_r_mention_tensor.cuda(
                ), train_r_lengths.cuda()
                train_e2_mention_tensor, train_e2_lengths = train_e2_mention_tensor.cuda(
                ), train_e2_lengths.cuda()

                e1_real_lstm, e1_img_lstm = model.get_mention_embedding(
                    train_e1_mention_tensor, 0, train_e1_lengths)
                r_real_lstm, r_img_lstm = model.get_mention_embedding(
                    train_r_mention_tensor, 1, train_r_lengths)
                e2_real_lstm, e2_img_lstm = model.get_mention_embedding(
                    train_e2_mention_tensor, 0, train_e2_lengths)
                #tail
                simi_t = model.complex_score_e1_r_with_all_ementions(
                    e1_real_lstm, e1_img_lstm, r_real_lstm, r_img_lstm,
                    e2_real_lstm, e2_img_lstm)
                #head
                simi_h = model.complex_score_e2_r_with_all_ementions(
                    e2_real_lstm, e2_img_lstm, r_real_lstm, r_img_lstm,
                    e1_real_lstm, e1_img_lstm)
                # change the loss suitably
                target = torch.eye(batch_size).cuda()
                # import pdb
                # pdb.set_trace()
                loss_t = BCEloss(simi_t.view(-1), target.view(-1))
                loss_h = BCEloss(simi_h.view(-1), target.view(-1))
                loss = (loss_h + loss_t) / 2
                loss /= target.size(0) * target.size(1)

                # Do the routine
                optimizer.zero_grad()
                loss.backward()
                #gradient clip?
                optimizer.step()

                if (iteration % args.print_loss_every == 0):
                    print("Current loss(avg, tail, head):", loss.item(),
                          loss_t.item(), loss_h.item())
                iteration += 1
            if (epoch % args.save_model_every == 0 and epoch != 0):
                utils.save_checkpoint(
                    {
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    },
                    args.output_dir + "/checkpoint_epoch_{}".format(epoch + 1))

    if args.do_eval:
        #eval code
        metrics = {}
        metrics['mr'] = 0
        metrics['mrr'] = 0
        metrics['hits1'] = 0
        metrics['hits10'] = 0
        metrics['hits50'] = 0
        metrics['mr_t'] = 0
        metrics['mrr_t'] = 0
        metrics['hits1_t'] = 0
        metrics['hits10_t'] = 0
        metrics['hits50_t'] = 0
        metrics['mr_h'] = 0
        metrics['mrr_h'] = 0
        metrics['hits1_h'] = 0
        metrics['hits10_h'] = 0
        metrics['hits50_h'] = 0

        if (args.resume and not args.do_train):
            print("Resuming from:", args.resume)
            checkpoint = torch.load(args.resume,
                                    map_location=lambda storage, loc: storage)
            model.load_state_dict(checkpoint['state_dict'])

        model.eval()

        # get embeddings for all entity mentions
        entity_mentions_tensor, entity_mentions_lengths = convert_string_to_indices(
            entity_mentions,
            etoken_map,
            maxlen=args.max_seq_length,
            use_tqdm=True)
        entity_mentions_tensor = entity_mentions_tensor.cuda()
        entity_mentions_lengths = entity_mentions_lengths.cuda()

        ementions_real_lis = []
        ementions_img_lis = []
        split = 100  # can't fit everything on the GPU at once, hence the split
        with torch.no_grad():
            for i in tqdm(
                    range(0, len(entity_mentions_tensor),
                          len(entity_mentions_tensor) // split)):
                data = entity_mentions_tensor[i:i +
                                              len(entity_mentions_tensor) //
                                              split, :]
                data_lengths = entity_mentions_lengths[
                    i:i + len(entity_mentions_tensor) // split]
                ementions_real_lstm, ementions_img_lstm = model.get_mention_embedding(
                    data, 0, data_lengths)
                # a = model.Et_im(entity_mentions_tensor[i:i+len(entity_mentions_tensor)//split,:])
                # b = model.Et_re(entity_mentions_tensor[i:i+len(entity_mentions_tensor)//split,:])

                # a_lstm,_ = model.lstm(a)
                # a_lstm = a_lstm[:,-1,:]

                # b_lstm,_ = model.lstm(b)
                # b_lstm = b_lstm[:,-1,:]

                ementions_real_lis.append(ementions_real_lstm.cpu())
                ementions_img_lis.append(ementions_img_lstm.cpu())
        del entity_mentions_tensor, ementions_real_lstm, ementions_img_lstm
        torch.cuda.empty_cache()
        ementions_real = torch.cat(ementions_real_lis).cuda()
        ementions_img = torch.cat(ementions_img_lis).cuda()
        ########################################################################
        if "olpbench" in args.data_dir:
            # test_kb = kb(os.path.join(args.data_dir,"test_data_sophis.txt"), em_map = em_map, rm_map = rm_map)
            test_kb = kb(os.path.join(args.data_dir, "test_data.txt"),
                         em_map=em_map,
                         rm_map=rm_map)
        else:
            test_kb = kb(os.path.join(args.data_dir, "test.txt"),
                         em_map=em_map,
                         rm_map=rm_map)

        print("Loading all_known pickled data...(takes times since large)")
        all_known_e2 = {}
        all_known_e1 = {}
        all_known_e2, all_known_e1 = pickle.load(
            open(
                os.path.join(
                    args.data_dir,
                    "all_knowns_{}_linked.pkl".format(args.train_data_type)),
                "rb"))

        test_e1_tokens_tensor, test_e1_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 0], etoken_map, maxlen=args.max_seq_length)
        test_r_tokens_tensor, test_r_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 1], rtoken_map, maxlen=args.max_seq_length)
        test_e2_tokens_tensor, test_e2_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 2], etoken_map, maxlen=args.max_seq_length)

        # e2_tensor = convert_string_to_indices(test_kb.triples[:,2], etoken_map)
        indices = torch.Tensor(
            range(len(test_kb.triples))
        )  #indices would be used to fetch alternative answers while evaluating
        test_data = TensorDataset(indices, test_e1_tokens_tensor,
                                  test_r_tokens_tensor, test_e2_tokens_tensor,
                                  test_e1_tokens_lengths,
                                  test_r_tokens_lengths,
                                  test_e2_tokens_lengths)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.eval_batch_size)
        split_dim_for_eval = 1
        if (args.embedding_dim >= 256 and "olpbench" in args.data_dir
                and "rotat" in args.model):
            split_dim_for_eval = 4
        if (args.embedding_dim >= 512 and "olpbench" in args.data_dir):
            split_dim_for_eval = 4
        if (args.embedding_dim >= 512 and "olpbench" in args.data_dir
                and "rotat" in args.model):
            split_dim_for_eval = 6
        split_dim_for_eval = 1  # overrides the heuristics above: evaluate without splitting
        for index, test_e1_tokens, test_r_tokens, test_e2_tokens, test_e1_lengths, test_r_lengths, test_e2_lengths in tqdm(
                test_dataloader, desc="Test dataloader"):
            print(metrics)
            test_e1_tokens, test_e1_lengths = test_e1_tokens.to(
                device), test_e1_lengths.to(device)
            test_r_tokens, test_r_lengths = test_r_tokens.to(
                device), test_r_lengths.to(device)
            test_e2_tokens, test_e2_lengths = test_e2_tokens.to(
                device), test_e2_lengths.to(device)
            with torch.no_grad():
                e1_real_lstm, e1_img_lstm = model.get_mention_embedding(
                    test_e1_tokens, 0, test_e1_lengths)
                r_real_lstm, r_img_lstm = model.get_mention_embedding(
                    test_r_tokens, 1, test_r_lengths)
                e2_real_lstm, e2_img_lstm = model.get_mention_embedding(
                    test_e2_tokens, 0, test_e2_lengths)

            for count in tqdm(range(index.shape[0]), desc="Evaluating"):
                # breakpoint()
                this_e1_real = e1_real_lstm[count].unsqueeze(0)
                this_e1_img = e1_img_lstm[count].unsqueeze(0)
                this_r_real = r_real_lstm[count].unsqueeze(0)
                this_r_img = r_img_lstm[count].unsqueeze(0)
                this_e2_real = e2_real_lstm[count].unsqueeze(0)
                this_e2_img = e2_img_lstm[count].unsqueeze(0)
                # import pdb
                # pdb.set_trace()
                simi_t = model.complex_score_e1_r_with_all_ementions(
                    this_e1_real,
                    this_e1_img,
                    this_r_real,
                    this_r_img,
                    ementions_real,
                    ementions_img,
                    split=split_dim_for_eval).squeeze(0)
                simi_h = model.complex_score_e2_r_with_all_ementions(
                    this_e2_real,
                    this_e2_img,
                    this_r_real,
                    this_r_img,
                    ementions_real,
                    ementions_img,
                    split=split_dim_for_eval).squeeze(0)
                # get known answers for filtered ranking
                ind = index[count]
                this_correct_mentions_e2 = test_kb.e2_all_answers[int(
                    ind.item())]
                this_correct_mentions_e1 = test_kb.e1_all_answers[int(
                    ind.item())]

                all_correct_mentions_e2 = all_known_e2.get(
                    (em_map[test_kb.triples[int(ind.item())][0]],
                     rm_map[test_kb.triples[int(ind.item())][1]]), [])
                all_correct_mentions_e1 = all_known_e1.get(
                    (em_map[test_kb.triples[int(ind.item())][2]],
                     rm_map[test_kb.triples[int(ind.item())][1]]), [])

                # compute metrics
                best_score = simi_t[this_correct_mentions_e2].max()
                simi_t[
                    all_correct_mentions_e2] = -20000000  # MOST NEGATIVE VALUE
                greatereq = simi_t.ge(best_score).float()
                equal = simi_t.eq(best_score).float()
                rank = greatereq.sum() + 1 + equal.sum() / 2.0

                metrics['mr_t'] += rank
                metrics['mrr_t'] += 1.0 / rank
                metrics['hits1_t'] += rank.le(1).float()
                metrics['hits10_t'] += rank.le(10).float()
                metrics['hits50_t'] += rank.le(50).float()

                best_score = simi_h[this_correct_mentions_e1].max()
                simi_h[
                    all_correct_mentions_e1] = -20000000  # MOST NEGATIVE VALUE
                greatereq = simi_h.ge(best_score).float()
                equal = simi_h.eq(best_score).float()
                rank = greatereq.sum() + 1 + equal.sum() / 2.0
                metrics['mr_h'] += rank
                metrics['mrr_h'] += 1.0 / rank
                metrics['hits1_h'] += rank.le(1).float()
                metrics['hits10_h'] += rank.le(10).float()
                metrics['hits50_h'] += rank.le(50).float()

                metrics['mr'] = (metrics['mr_h'] + metrics['mr_t']) / 2
                metrics['mrr'] = (metrics['mrr_h'] + metrics['mrr_t']) / 2
                metrics['hits1'] = (metrics['hits1_h'] +
                                    metrics['hits1_t']) / 2
                metrics['hits10'] = (metrics['hits10_h'] +
                                     metrics['hits10_t']) / 2
                metrics['hits50'] = (metrics['hits50_h'] +
                                     metrics['hits50_t']) / 2

        for key in metrics:
            metrics[key] = metrics[key] / len(test_kb.triples)
        print(metrics)
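complex_score_e1_r_with_all_ementions is likewise external to these snippets. Assuming it implements the standard ComplEx score Re(<e1 * r, conj(e2)>) broadcast over every candidate entity, a compact sketch (the function and argument names are assumptions, not the model's actual method):

import torch

def complex_score_all_tails(e1_re, e1_im, r_re, r_im, E_re, E_im):
    # e1_*, r_*: (1, d) query embeddings; E_*: (num_entities, d) candidates
    q_re = e1_re * r_re - e1_im * r_im  # real part of the product e1 * r
    q_im = e1_re * r_im + e1_im * r_re  # imaginary part of e1 * r
    # Re(<q, conj(e2)>) summed over d, for every candidate e2
    return q_re @ E_re.t() + q_im @ E_im.t()  # (1, num_entities)

The head-side variant would swap e1 for e2 and conjugate the relation accordingly.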
Example #16
0
def main():
    data_dir = "data/olpbench"
    head_or_tail = "tail"
    sample = 100
    train_kb = kb(os.path.join(data_dir, "train_data_thorough.txt"),
                  em_map=None,
                  rm_map=None)
    freq_r = {}
    freq_r_e2 = {}
    for triple in train_kb.triples:
        e1 = triple[0].item()
        r = triple[1].item()
        e2 = triple[2].item()
        if r not in freq_r:
            freq_r[r] = 0
        freq_r[r] += 1

        # if (r,e2) not in freq_r_e2:
        # 	freq_r_e2[(r,e2)] = 0
        # freq_r_e2[(r,e2)] += 1

        if r not in freq_r_e2:
            freq_r_e2[r] = {}
        if e2 not in freq_r_e2[r]:
            freq_r_e2[r][e2] = 0
        freq_r_e2[r][e2] += 1

    pred_file = "helper_scripts/keshav_xt-preds/validation_data_linked_mention.txt.tail_thorough_n2_e50_lr0.1.stage1"
    _, rm_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "relation_id_map.txt"))
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(data_dir, "mapped_to_ids", "entity_id_map.txt"))

    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    print("Loading all_known pickled data...(takes times since large)")
    all_known_e2 = {}
    all_known_e1 = {}
    all_known_e2, all_known_e1 = pickle.load(
        open(os.path.join(data_dir, "all_knowns_thorough_linked.pkl"), "rb"))

    hits_1_triple = []
    hits_1_correct_answers = []
    hits_1_model_top10 = []

    nothits_50_triple = []
    nothits_50_correct_answers = []
    nothits_50_model_top10 = []

    ranks = []
    lines = open(pred_file).readlines()
    for line in tqdm(lines, desc="preds"):
        line = line.strip().split("\t")
        e1 = line[0]
        r = line[1]
        e2 = line[2]

        this_correct_mentions_e2_raw = line[4].split("|||")
        this_correct_mentions_e2 = []
        for mention in this_correct_mentions_e2_raw:
            if mention in em_map:
                this_correct_mentions_e2.append(em_map[mention])
        this_correct_mentions_e1_raw = line[3].split("|||")
        this_correct_mentions_e1 = []
        for mention in this_correct_mentions_e1_raw:
            if mention in em_map:
                this_correct_mentions_e1.append(em_map[mention])

        all_correct_mentions_e2 = all_known_e2.get((em_map[e1], rm_map[r]), [])
        all_correct_mentions_e1 = all_known_e1.get((em_map[e2], rm_map[r]), [])

        indices_scores = torch.tensor(ast.literal_eval(line[5]))
        topk_scores = indices_scores[:, 1]
        indices = indices_scores[:, 0].long()
        if (head_or_tail == "tail"):
            this_gold = this_correct_mentions_e2
            all_gold = all_correct_mentions_e2
        else:
            this_gold = this_correct_mentions_e1
            all_gold = all_correct_mentions_e1

        best_score = -2000000000
        for i, j in enumerate(indices):
            if j in this_gold:
                best_score = max(best_score, topk_scores[i].item())
                topk_scores[i] = -2000000000

        for i, j in enumerate(indices):
            if j in all_gold:
                topk_scores[i] = -2000000000
        greatereq = topk_scores.ge(best_score).float()
        equal = topk_scores.eq(best_score).float()
        rank = (greatereq.sum() + 1 + equal.sum() / 2.0).item()
        if rank <= 1:
            hits_1_triple.append([e1, r, e2])
            hits_1_correct_answers.append(
                [entity_mentions[x] for x in this_gold])
            hits_1_model_top10.append([])
        elif rank > 50:
            nothits_50_triple.append([e1, r, e2])
            nothits_50_correct_answers.append(
                [entity_mentions[x] for x in this_gold])
            nothits_50_model_top10.append(
                [entity_mentions[x.item()] for x in indices])
        ranks.append(rank)

    result = {}
    result["hits1"] = 0
    result["hits10"] = 0
    result["hits50"] = 0
    for rank in ranks:
        if rank <= 1:
            result["hits1"] += 1
        if rank <= 10:
            result["hits10"] += 1
        if rank <= 50:
            result["hits50"] += 1
    result["hits1"] /= len(lines)
    result["hits10"] /= len(lines)
    result["hits50"] /= len(lines)
    print(result)

    # print format
    # triple, correct answers, model predictions, r: freq, r_gold: freq, r_prediction: freq, r_max-e2: freq
    indices = list(range(len(hits_1_triple)))
    random.shuffle(indices)
    indices = indices[:sample]
    for ind in indices:
        # ratio = " {} /{}".format(freq_r_e2.get((hits_1_triple[ind][1],hits_1_triple[ind][2]),0),freq_r.get(hits_1_triple[ind][1],0))
        freq_of_r = freq_r.get(hits_1_triple[ind][1], 0)
        freq_of_r_gold = freq_r_e2.get(hits_1_triple[ind][1],
                                       {}).get(hits_1_triple[ind][2], 0)
        freq_of_r_pred = "N/A"
        freq_of_r_maxe = max(
            list(freq_r_e2.get(hits_1_triple[ind][1], {
                0: 0
            }).values()))
        print(hits_1_triple[ind], "|", hits_1_correct_answers[ind], "|",
              hits_1_model_top10[ind], "|", freq_of_r, "|", freq_of_r_gold,
              "|", freq_of_r_pred, "|", freq_of_r_maxe)
    print(
        "---------------------------------------------------------------------------------------------"
    )
    indices = list(range(len(nothits_50_triple)))
    random.shuffle(indices)
    indices = indices[:sample]
    for ind in indices:
        # ratio = " {} /{}".format(freq_r_e2.get((nothits_50_triple[ind][1],nothits_50_triple[ind][2]),0),freq_r.get(nothits_50_triple[ind][1],0))
        freq_of_r = freq_r.get(nothits_50_triple[ind][1], 0)
        freq_of_r_gold = freq_r_e2.get(nothits_50_triple[ind][1],
                                       {}).get(nothits_50_triple[ind][2], 0)
        freq_of_r_pred = freq_r_e2.get(nothits_50_triple[ind][1],
                                       {}).get(nothits_50_model_top10[ind][0],
                                               0)
        freq_of_r_maxe = max(
            list(freq_r_e2.get(nothits_50_triple[ind][1], {
                0: 0
            }).values()))
        print(nothits_50_triple[ind], "|", nothits_50_correct_answers[ind],
              "|", nothits_50_model_top10[ind], "|", freq_of_r, "|",
              freq_of_r_gold, "|", freq_of_r_pred, "|", freq_of_r_maxe)
def main(args):
	random.seed(args.seed)
	np.random.seed(args.seed)
	torch.manual_seed(args.seed)
	
	# for batch in train_loader:
	# 	inputs, \
	# 	normalizer_loss, \
	# 	normalizer_metric, \
	# 	labels, \
	# 	label_ids, \
	# 	filter_mask, \
	# 	batch_shared_entities = train_data.input_and_labels_to_device(
	# 		batch,
	# 		training=True,
	# 		device=train_data.device
	# 	)
	# 	import pdb
	# 	pdb.set_trace()


	# read token maps
	etokens, etoken_map = utils.get_tokens_map(os.path.join(args.data_dir,"mapped_to_ids","entity_token_id_map.txt"))
	rtokens, rtoken_map = utils.get_tokens_map(os.path.join(args.data_dir,"mapped_to_ids","relation_token_id_map.txt"))
	entity_mentions,em_map = utils.read_mentions(os.path.join(args.data_dir,"mapped_to_ids","entity_id_map.txt"))
	relation_mentions,rm_map = utils.read_mentions(os.path.join(args.data_dir,"mapped_to_ids","relation_id_map.txt"))

	# create entity_token_indices and entity_lengths
	# [[max length indices for entity 0 ], [max length indices for entity 1], [max length indices for entity 2], ...]
	# [length of entity 0, length of entity 1, length of entity 2, ...]
	entity_token_indices, entity_lengths = utils.get_token_indices_from_mention_indices(entity_mentions, etoken_map, maxlen=args.max_seq_length, use_tqdm=True)
	relation_token_indices, relation_lengths = utils.get_token_indices_from_mention_indices(relation_mentions, rtoken_map, maxlen=args.max_seq_length, use_tqdm=True)

	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

	if not args.do_train and not args.do_eval:
		raise ValueError("At least one of `do_train` or `do_eval` must be True.")

	#train code (+1 for unk token)
	if args.model=="complex":
		if args.separate_lstms:
			model = complexLSTM_2(len(etoken_map)+1,len(rtoken_map)+1,args.embedding_dim, initial_token_embedding =args.initial_token_embedding, entity_tokens = etokens, relation_tokens = rtokens, lstm_dropout=args.lstm_dropout)
		else:
			model = complexLSTM(len(etoken_map)+1,len(rtoken_map)+1,args.embedding_dim, initial_token_embedding =args.initial_token_embedding, entity_tokens = etokens, relation_tokens = rtokens, lstm_dropout=args.lstm_dropout)
	elif args.model == "rotate":
		model = rotatELSTM(len(etoken_map)+1,len(rtoken_map)+1,args.embedding_dim, initial_token_embedding =args.initial_token_embedding, entity_tokens = etokens, relation_tokens = rtokens, gamma = args.gamma_rotate, lstm_dropout=args.lstm_dropout)
	if(args.do_train):
		data_config = {'input_file': 'train_data_thorough.txt', 'batch_size': args.train_batch_size, 'use_batch_shared_entities': True, 'min_size_batch_labels': args.train_batch_size, 'max_size_prefix_label': 64, 'device': 0}
		expt_settings = {'loss': 'bce', 'replace_entities_by_tokens': True, 'replace_relations_by_tokens': True, 'max_lengths_tuple': [10, 10]}
		train_data = OneToNMentionRelationDataset(dataset_dir=os.path.join(args.data_dir,"mapped_to_ids"), is_training_data=True, **data_config, **expt_settings)
		train_data.create_data_tensors(
			dataset_dir=os.path.join(args.data_dir,"mapped_to_ids"),
			train_input_file='train_data_thorough.txt',
			valid_input_file='validation_data_linked.txt',
			test_input_file='test_data.txt',
		)
		train_loader = train_data.get_loader(
			shuffle=True,
			num_workers=8,
			drop_last=True,
		)
		optimizer = torch.optim.Adagrad(model.parameters(),lr=args.learning_rate,weight_decay=args.weight_decay)

		if(args.resume):
			print("Resuming from:",args.resume)
			checkpoint = torch.load(args.resume)
			model.load_state_dict(checkpoint['state_dict'])
			optimizer.load_state_dict(checkpoint['optimizer'])
			#Load other things too if required

		model.train()
		# if "olpbench" in args.data_dir:
		# 	train_kb = kb(os.path.join(args.data_dir,"train_data_{}.txt".format(args.train_data_type)), em_map = em_map, rm_map = rm_map)
		# 	# train_kb = kb(os.path.join(args.data_dir,"train_data_thorough_r_sorted.txt"), em_map = em_map, rm_map = rm_map)
		# 	# train_kb = kb(os.path.join(args.data_dir,"test_data.txt"), em_map = em_map, rm_map = rm_map)

		# else:
		# 	train_kb = kb(os.path.join(args.data_dir,"train.txt"), em_map = em_map, rm_map = rm_map)
		
		# train_data = Dataset(train_kb.triples)
		# train_sampler = RandomSampler(train_data,replacement=False)
		# #train_sampler = SequentialSampler(train_data)
		# train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

		# crossEntropyLoss = torch.nn.CrossEntropyLoss(reduction='mean')
		BCEloss = torch.nn.BCEWithLogitsLoss(reduction='sum')
		for epoch in tqdm(range(0,args.num_train_epochs), desc="epoch"):
			iteration = 0
			for batch in tqdm(train_loader, desc="Train dataloader"):
				inputs, \
				normalizer_loss, \
				normalizer_metric, \
				labels, \
				label_ids, \
				filter_mask, \
				batch_shared_entities = train_data.input_and_labels_to_device(
					batch,
					training=True,
					device="cpu"
				)
				labels = labels.cuda()
				all_outputs = []
				for mode,model_inputs in zip(["head","tail"],inputs):
					if model_inputs is None:
						continue
					# subtract two from the author's indices because our map is offset by 2
					if mode=="head":
						batch_e2_indices = model_inputs[1] - 2
						batch_r_indices  = model_inputs[0] - 2
						batch_e1_indices = batch_shared_entities - 2
					else:
						batch_e1_indices = model_inputs[0] - 2
						batch_r_indices  = model_inputs[1] - 2
						batch_e2_indices = batch_shared_entities - 2
					# import pdb 
					# pdb.set_trace()

					# convert these indices back into string (compatibility with my code)
					# tik = time.time()
					# batch_e1_strings = convert_mention_index_to_string(batch_e1_indices.squeeze(1), entity_mentions)
					# batch_r_strings  = convert_mention_index_to_string(batch_r_indices.squeeze(1), relation_mentions)
					# batch_e2_strings = convert_mention_index_to_string(batch_e2_indices.squeeze(1), entity_mentions)
					# # print("convert_mention_index_to_string:",time.time() - tik)
					# # do what you used to do now
					# # tik = time.time()
					# train_e1_mention_tensor, train_e1_lengths = convert_string_to_indices(batch_e1_strings,etoken_map,maxlen=args.max_seq_length)
					# train_r_mention_tensor, train_r_lengths   = convert_string_to_indices(batch_r_strings,rtoken_map,maxlen=args.max_seq_length)
					# train_e2_mention_tensor, train_e2_lengths = convert_string_to_indices(batch_e2_strings,etoken_map,maxlen=args.max_seq_length)
					# print("convert_string_to_indices:",time.time() - tik)
					
					train_e1_mention_tensor, train_e1_lengths = convert_mention_to_token_indices(batch_e1_indices.squeeze(1), entity_token_indices, entity_lengths)
					train_r_mention_tensor, train_r_lengths   = convert_mention_to_token_indices(batch_r_indices.squeeze(1), relation_token_indices, relation_lengths)
					train_e2_mention_tensor, train_e2_lengths = convert_mention_to_token_indices(batch_e2_indices.squeeze(1), entity_token_indices, entity_lengths)


					train_e1_mention_tensor, train_e1_lengths = train_e1_mention_tensor.cuda(), train_e1_lengths.cuda()
					train_r_mention_tensor, train_r_lengths   = train_r_mention_tensor.cuda(), train_r_lengths.cuda()
					train_e2_mention_tensor, train_e2_lengths = train_e2_mention_tensor.cuda(), train_e2_lengths.cuda()

					# tik = time.time()
					e1_real_lstm, e1_img_lstm = model.get_mention_embedding(train_e1_mention_tensor,0,train_e1_lengths)
					r_real_lstm, r_img_lstm   = model.get_mention_embedding(train_r_mention_tensor,1,train_r_lengths)
					e2_real_lstm, e2_img_lstm = model.get_mention_embedding(train_e2_mention_tensor,0,train_e2_lengths)
					# print("get_mention_embedding:",time.time() - tik)

					# tik = time.time()
					if mode=="head":
						output = model.complex_score_e2_r_with_all_ementions(e2_real_lstm,e2_img_lstm,r_real_lstm,r_img_lstm,e1_real_lstm,e1_img_lstm)
					else:
						output = model.complex_score_e1_r_with_all_ementions(e1_real_lstm,e1_img_lstm,r_real_lstm,r_img_lstm,e2_real_lstm,e2_img_lstm)
					# print("model_scoring:",time.time() - tik)

					all_outputs.append(output)
				
				all_outputs = torch.cat(all_outputs)
				loss = BCEloss(all_outputs.view(-1),labels.view(-1))
				# loss = loss.sum()
				loss /= normalizer_loss
				optimizer.zero_grad()
				loss.backward()
				optimizer.step()
				if(iteration%args.print_loss_every==0):
					print("Current loss:",loss.item())
				iteration+=1
			if(epoch%args.save_model_every==0):
				utils.save_checkpoint({
						'state_dict':model.state_dict(),
						'optimizer':optimizer.state_dict()
						},args.output_dir+"/checkpoint_epoch_{}".format(epoch+1))

		# for epoch in tqdm(range(0,args.num_train_epochs),desc="epoch"):
		# 	iteration = 0
		# 	for train_e1_batch, train_r_batch, train_e2_batch in tqdm(train_dataloader,desc="Train dataloader"):
		# 		# skip this batch
		# 		if(random.random()<args.skip_train_prob):
		# 			continue	
		# 		batch_size = len(train_e1_batch)
		# 		train_e1_mention_tensor, train_e1_lengths = convert_string_to_indices(train_e1_batch,etoken_map,maxlen=args.max_seq_length)
		# 		train_r_mention_tensor, train_r_lengths = convert_string_to_indices(train_r_batch,rtoken_map,maxlen=args.max_seq_length)
		# 		train_e2_mention_tensor, train_e2_lengths = convert_string_to_indices(train_e2_batch,etoken_map,maxlen=args.max_seq_length)

		# 		train_e1_mention_tensor, train_e1_lengths = train_e1_mention_tensor.cuda(), train_e1_lengths.cuda()
		# 		train_r_mention_tensor, train_r_lengths = train_r_mention_tensor.cuda(), train_r_lengths.cuda()
		# 		train_e2_mention_tensor, train_e2_lengths = train_e2_mention_tensor.cuda(), train_e2_lengths.cuda()


		# 		e1_real_lstm, e1_img_lstm = model.get_mention_embedding(train_e1_mention_tensor,0,train_e1_lengths)
		# 		r_real_lstm, r_img_lstm = model.get_mention_embedding(train_r_mention_tensor,1,train_r_lengths)
		# 		e2_real_lstm, e2_img_lstm = model.get_mention_embedding(train_e2_mention_tensor,0,train_e2_lengths)
		# 		#tail
		# 		simi_t = model.complex_score_e1_r_with_all_ementions(e1_real_lstm,e1_img_lstm,r_real_lstm,r_img_lstm,e2_real_lstm,e2_img_lstm)
		# 		#head
		# 		simi_h = model.complex_score_e2_r_with_all_ementions(e2_real_lstm,e2_img_lstm,r_real_lstm,r_img_lstm,e1_real_lstm,e1_img_lstm)
		# 		# change the loss suitably
		# 		target = torch.eye(batch_size).cuda()
		# 		# import pdb
		# 		# pdb.set_trace()
		# 		loss_t = BCEloss(simi_t.view(-1),target.view(-1))
		# 		loss_h = BCEloss(simi_h.view(-1),target.view(-1))
		# 		loss = (loss_h+loss_t)/2
		# 		loss /= target.size(0) * target.size(1)

		# 		# Do the routine
		# 		optimizer.zero_grad()
		# 		loss.backward()
		# 		#gradient clip?
		# 		optimizer.step()

		# 		if(iteration%args.print_loss_every==0):
		# 			print("Current loss(avg, tail, head):",loss.item(), loss_t.item(), loss_h.item())
		# 		iteration+=1
		# 	if(epoch%args.save_model_every==0 and epoch!=0):
		# 		utils.save_checkpoint({
		# 				'state_dict':model.state_dict(),
		# 				'optimizer':optimizer.state_dict()
		# 				},args.output_dir+"/checkpoint_epoch_{}".format(epoch+1))


	if args.do_eval:
		#eval code
		metrics = {}
		metrics['mr'] = 0
		metrics['mrr'] = 0
		metrics['hits1'] = 0
		metrics['hits10'] = 0
		metrics['hits50'] = 0
		metrics['mr_t'] = 0
		metrics['mrr_t'] = 0
		metrics['hits1_t'] = 0
		metrics['hits10_t'] = 0
		metrics['hits50_t'] = 0
		metrics['mr_h'] = 0
		metrics['mrr_h'] = 0
		metrics['hits1_h'] = 0
		metrics['hits10_h'] = 0
		metrics['hits50_h'] = 0

		if(args.resume and not args.do_train):
			print("Resuming from:",args.resume)
			checkpoint = torch.load(args.resume,map_location=lambda storage, loc: storage)
			model.load_state_dict(checkpoint['state_dict'])

		model.eval()

		# get embeddings for all entity mentions
		entity_mentions_tensor, entity_mentions_lengths = convert_string_to_indices(entity_mentions,etoken_map,maxlen=args.max_seq_length,use_tqdm=True)
		entity_mentions_tensor = entity_mentions_tensor.cuda()
		entity_mentions_lengths = entity_mentions_lengths.cuda()

		ementions_real_lis = []
		ementions_img_lis = []
		split = 100 # can't fit everything on the GPU at once, hence the split
		with torch.no_grad():
			for i in tqdm(range(0,len(entity_mentions_tensor),len(entity_mentions_tensor)//split)):
				data = entity_mentions_tensor[i:i+len(entity_mentions_tensor)//split,:]
				data_lengths = entity_mentions_lengths[i:i+len(entity_mentions_tensor)//split]
				ementions_real_lstm,ementions_img_lstm = model.get_mention_embedding(data,0,data_lengths)			
				# a = model.Et_im(entity_mentions_tensor[i:i+len(entity_mentions_tensor)//split,:])
				# b = model.Et_re(entity_mentions_tensor[i:i+len(entity_mentions_tensor)//split,:])
				
				# a_lstm,_ = model.lstm(a)
				# a_lstm = a_lstm[:,-1,:]

				
				# b_lstm,_ = model.lstm(b)
				# b_lstm = b_lstm[:,-1,:]

				ementions_real_lis.append(ementions_real_lstm.cpu())
				ementions_img_lis.append(ementions_img_lstm.cpu())
		del entity_mentions_tensor,ementions_real_lstm,ementions_img_lstm
		torch.cuda.empty_cache()
		ementions_real = torch.cat(ementions_real_lis).cuda()
		ementions_img = torch.cat(ementions_img_lis).cuda()
		########################################################################
		if "olpbench" in args.data_dir:
			# test_kb = kb(os.path.join(args.data_dir,"test_data_sophis.txt"), em_map = em_map, rm_map = rm_map)
			test_kb = kb(os.path.join(args.data_dir,"test_data.txt"), em_map = em_map, rm_map = rm_map)
		else:
			test_kb = kb(os.path.join(args.data_dir,"test.txt"), em_map = em_map, rm_map = rm_map)

		print("Loading all_known pickled data...(takes times since large)")
		all_known_e2 = {}
		all_known_e1 = {}
		all_known_e2,all_known_e1 = pickle.load(open(os.path.join(args.data_dir,"all_knowns_{}_linked.pkl".format(args.train_data_type)),"rb"))


		test_e1_tokens_tensor, test_e1_tokens_lengths = convert_string_to_indices(test_kb.triples[:,0], etoken_map,maxlen=args.max_seq_length)
		test_r_tokens_tensor, test_r_tokens_lengths = convert_string_to_indices(test_kb.triples[:,1], rtoken_map,maxlen=args.max_seq_length)
		test_e2_tokens_tensor, test_e2_tokens_lengths = convert_string_to_indices(test_kb.triples[:,2], etoken_map,maxlen=args.max_seq_length)
		
		# e2_tensor = convert_string_to_indices(test_kb.triples[:,2], etoken_map)
		indices = torch.Tensor(range(len(test_kb.triples))) #indices would be used to fetch alternative answers while evaluating
		test_data = TensorDataset(indices, test_e1_tokens_tensor, test_r_tokens_tensor, test_e2_tokens_tensor, test_e1_tokens_lengths, test_r_tokens_lengths, test_e2_tokens_lengths)
		test_sampler = SequentialSampler(test_data)
		test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=args.eval_batch_size)
		split_dim_for_eval = 1
		if(args.embedding_dim>=256 and "olpbench" in args.data_dir and "rotat" in args.model):
			split_dim_for_eval = 4
		if(args.embedding_dim>=512 and "olpbench" in args.data_dir):
			split_dim_for_eval = 4
		if(args.embedding_dim>=512 and "olpbench" in args.data_dir and "rotat" in args.model):
			split_dim_for_eval = 6
		split_dim_for_eval = 1 # overrides the heuristics above: evaluate without splitting
		for index, test_e1_tokens, test_r_tokens, test_e2_tokens, test_e1_lengths, test_r_lengths, test_e2_lengths in tqdm(test_dataloader,desc="Test dataloader"):
			print(metrics)
			test_e1_tokens, test_e1_lengths = test_e1_tokens.to(device), test_e1_lengths.to(device)
			test_r_tokens, test_r_lengths = test_r_tokens.to(device), test_r_lengths.to(device)
			test_e2_tokens, test_e2_lengths = test_e2_tokens.to(device), test_e2_lengths.to(device)
			with torch.no_grad():
				e1_real_lstm, e1_img_lstm = model.get_mention_embedding(test_e1_tokens,0, test_e1_lengths)
				r_real_lstm, r_img_lstm = model.get_mention_embedding(test_r_tokens,1, test_r_lengths)	
				e2_real_lstm, e2_img_lstm = model.get_mention_embedding(test_e2_tokens,0, test_e2_lengths)


			for count in tqdm(range(index.shape[0]), desc="Evaluating"):
				# breakpoint()
				this_e1_real = e1_real_lstm[count].unsqueeze(0)
				this_e1_img  = e1_img_lstm[count].unsqueeze(0)
				this_r_real  = r_real_lstm[count].unsqueeze(0)
				this_r_img   = r_img_lstm[count].unsqueeze(0)
				this_e2_real = e2_real_lstm[count].unsqueeze(0)
				this_e2_img  = e2_img_lstm[count].unsqueeze(0)
				simi_t = model.complex_score_e1_r_with_all_ementions(this_e1_real,this_e1_img,this_r_real,this_r_img,ementions_real,ementions_img,split=split_dim_for_eval).squeeze(0)
				simi_h = model.complex_score_e2_r_with_all_ementions(this_e2_real,this_e2_img,this_r_real,this_r_img,ementions_real,ementions_img,split=split_dim_for_eval).squeeze(0)
				# get known answers for filtered ranking
				ind = index[count]
				this_correct_mentions_e2 = test_kb.e2_all_answers[int(ind.item())]
				this_correct_mentions_e1 = test_kb.e1_all_answers[int(ind.item())] 

				all_correct_mentions_e2 = all_known_e2.get((em_map[test_kb.triples[int(ind.item())][0]],rm_map[test_kb.triples[int(ind.item())][1]]),[])
				all_correct_mentions_e1 = all_known_e1.get((em_map[test_kb.triples[int(ind.item())][2]],rm_map[test_kb.triples[int(ind.item())][1]]),[])
				
				# compute metrics
				best_score = simi_t[this_correct_mentions_e2].max()
				simi_t[all_correct_mentions_e2] = -20000000 # MOST NEGATIVE VALUE
				greatereq = simi_t.ge(best_score).float()
				equal = simi_t.eq(best_score).float()
				rank = greatereq.sum()+1+equal.sum()/2.0

				metrics['mr_t'] += rank
				metrics['mrr_t'] += 1.0/rank
				metrics['hits1_t'] += rank.le(1).float()
				metrics['hits10_t'] += rank.le(10).float()
				metrics['hits50_t'] += rank.le(50).float()

				best_score = simi_h[this_correct_mentions_e1].max()
				simi_h[all_correct_mentions_e1] = -20000000 # MOST NEGATIVE VALUE
				greatereq = simi_h.ge(best_score).float()
				equal = simi_h.eq(best_score).float()
				rank = greatereq.sum()+1+equal.sum()/2.0
				metrics['mr_h'] += rank
				metrics['mrr_h'] += 1.0/rank
				metrics['hits1_h'] += rank.le(1).float()
				metrics['hits10_h'] += rank.le(10).float()
				metrics['hits50_h'] += rank.le(50).float()

				metrics['mr'] = (metrics['mr_h']+metrics['mr_t'])/2
				metrics['mrr'] = (metrics['mrr_h']+metrics['mrr_t'])/2
				metrics['hits1'] = (metrics['hits1_h']+metrics['hits1_t'])/2
				metrics['hits10'] = (metrics['hits10_h']+metrics['hits10_t'])/2
				metrics['hits50'] = (metrics['hits50_h']+metrics['hits50_t'])/2

		for key in metrics:
			metrics[key] = metrics[key] / len(test_kb.triples)
		print(metrics)
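The commented-out training loop near the end of this example used an identity matrix as the BCE target: each e1-r query in the batch is scored against every e2 in the same batch, the diagonal is the true pair, and the off-diagonal entries serve as in-batch negatives. A condensed sketch of that loss under the same assumptions:

import torch

def in_batch_bce_loss(scores):
	# scores: (B, B) logits; scores[i, j] pairs query i with candidate j
	# the diagonal holds the true pairs, off-diagonals act as negatives
	target = torch.eye(scores.size(0), device=scores.device)
	bce = torch.nn.BCEWithLogitsLoss(reduction='sum')
	# same normalization as loss /= target.size(0) * target.size(1)
	return bce(scores.view(-1), target.view(-1)) / target.numel()

The live loop instead takes its labels and normalizer from OneToNMentionRelationDataset, which batches each query against a shared candidate entity set.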