Example #1
relation_vec_E = rel_vec_E[relation2id[rel], :]

ent_vec_R = np.loadtxt(dataPath_ + '/entity2vec.bern')
rel_vec_R = np.loadtxt(dataPath_ + '/relation2vec.bern')
M = np.loadtxt(dataPath_ + '/A.bern')
M = M.reshape([-1, 50, 50])
relation_vec_R = rel_vec_R[relation2id[rel], :]
M_vec = M[relation2id[rel], :, :]

_, named_paths = get_features()
path_weights = []
for path in named_paths:
    weight = 1.0 / len(path)
    path_weights.append(weight)
path_weights = np.array(path_weights)
kb = KB()
kb_inv = KB()

f = open(dataPath_ + '/graph.txt')
kb_lines = f.readlines()
f.close()

for line in kb_lines:
    e1 = line.split()[0]
    rel = line.split()[1]
    e2 = line.split()[2]
    kb.addRelation(e1, rel, e2)
    kb_inv.addRelation(e2, rel, e1)

f = open(test_data_path)
test_data = f.readlines()
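The fragment above loads two sets of pretrained embeddings (the '_E' files are presumably TransE-style, the '_R' files plus the 50x50 projection matrices in A.bern TransR-style) before building the KB from graph.txt. As a hedged illustration of how such a per-relation projection matrix can score a candidate triple, here is a small sketch; it is not code from this snippet and assumes a row-vector h @ M convention, which may differ from the repository's:

import numpy as np

def transr_score(h_vec, r_vec, t_vec, M_vec):
    """Plausibility of (h, r, t) under a TransR-style model; higher is better."""
    h_proj = np.matmul(h_vec, M_vec)  # project the head into the relation's space
    t_proj = np.matmul(t_vec, M_vec)  # project the tail into the relation's space
    return -np.sum(np.abs(h_proj + r_vec - t_proj))  # negative L1 distance

# Usage, assuming an entity2id mapping analogous to the relation2id one above:
# transr_score(ent_vec_R[entity2id[e1]], relation_vec_R, ent_vec_R[entity2id[e2]], M_vec)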
Example #2
def evaluate_logic():  # called from main(); returns the RL MAP
    kb = KB()  # KB class from BFS/KB.py
    kb_inv = KB()  # KB class from BFS/KB.py

    f = open(dataPath_ + '/graph.txt')
    kb_lines = f.readlines()
    f.close()

    for line in kb_lines:
        e1 = line.split()[0]
        rel = line.split()[1]
        e2 = line.split()[2]
        kb.addRelation(e1, rel, e2)
        kb_inv.addRelation(e2, rel, e1)

    _, named_paths = get_features()

    model = train(kb, kb_inv, named_paths)

    f = open(dataPath_ +
             '/sort_test.pairs')  # sort_test.pairs: test pairs sorted alphabetically
    test_data = f.readlines()  # the test data for the whole model
    f.close()
    test_pairs = []
    test_labels = []
    # queries = set()
    for line in test_data:
        e1 = line.split(',')[0].replace('thing$', '')
        # e1 = '/' + e1[0] + '/' + e1[2:]
        e2 = line.split(',')[1].split(':')[0].replace('thing$', '')
        # e2 = '/' + e2[0] + '/' + e2[2:]
        if (e1 not in kb.entities) or (e2 not in kb.entities):
            continue
        test_pairs.append((e1, e2))
        label = 1 if line[-2] == '+' else 0
        test_labels.append(label)

    # print ('test_pairs:',test_pairs)
    # print ('test_labels:',test_labels)
    aps = []
    query = test_pairs[0][0]
    y_true = []
    y_score = []
    hit_1_list = []
    hit_3_list = []
    hit_10_list = []
    mrr_list = []

    score_all = []

    for idx, sample in enumerate(test_pairs):
        # print 'query node: ', sample[0], idx
        if sample[0] == query:
            # print 'query:',query
            # print 'sample:',sample[0]
            # print 'y_true:', y_true
            features = []
            for path in named_paths:
                features.append(
                    int(bfs_two(sample[0], sample[1], path, kb, kb_inv)))

            # score is an np.array of shape [1, 1], dtype float32
            score = model.predict(np.reshape(features, [1, -1]))

            # score = np.sum(features)
            # print ('score:',score)
            score_all.append(score[0])
            y_score.append(score)
            y_true.append(test_labels[idx])
        else:  # move on to the next query group
            # print 'query:',query
            # print 'sample:',sample[0]
            # print 'y_true:', y_true
            # raw_input('----------')
            query = sample[0]
            # print (y_true)
            count = list(zip(y_score, y_true))
            count.sort(key=lambda x: x[0], reverse=True)
            # print ('count:',len(count))

            ranks = []
            correct = 0

            hit_1 = 0
            hit_3 = 0
            hit_10 = 0
            mrr = 0

            # most query groups contain only one correct item, because in
            # sort_test.pairs each query usually has one '+' pair and several '-' pairs
            for idx_, item in enumerate(count):
                if item[1] == 1:
                    correct += 1
                    ranks.append(correct / (1.0 + idx_))

                    # only use the first positive sample to evaluate hits@n
                    if correct == 1:
                        if idx_ < 10:
                            hit_10 += 1
                            if idx_ < 3:
                                hit_3 += 1
                                if idx_ < 1:
                                    hit_1 += 1
                    if mrr == 0:
                        mrr = 1 / (1.0 + idx_)

            if len(ranks) == 0:
                aps.append(0)
            else:
                aps.append(np.mean(ranks))

            hit_1_list.append(hit_1)
            hit_3_list.append(hit_3)
            hit_10_list.append(hit_10)
            if correct == 0:
                mrr_list.append(0)
            else:
                mrr_list.append(mrr / correct)
            # print np.mean(ranks)
            # if len(aps) % 10 == 0:
            # 	print 'How many queries:', len(aps)
            # 	print np.mean(aps)
            y_true = []
            y_score = []
            features = []
            for path in named_paths:
                features.append(
                    int(bfs_two(sample[0], sample[1], path, kb,
                                kb_inv)))  # bfs_two returns True or False

            # features = features*path_weights
            # score = np.inner(features, path_weights)
            # score = np.sum(features)
            score = model.predict(np.reshape(features, [1, -1]))

            score_all.append(score[0])
            y_score.append(score)
            y_true.append(test_labels[idx])

    count = list(zip(y_score, y_true))
    count.sort(key=lambda x: x[0], reverse=True)
    # print ('count:',count)

    ranks = []
    correct = 0

    hit_1 = 0
    hit_3 = 0
    hit_10 = 0
    mrr = 0

    for idx_, item in enumerate(count):
        if item[1] == 1:
            correct += 1
            ranks.append(correct / (1.0 + idx_))

            # only use the first positive sample to evaluate hits@n
            if correct == 1:
                if idx_ < 10:
                    hit_10 += 1
                    if idx_ < 3:
                        hit_3 += 1
                        if idx_ < 1:
                            hit_1 += 1

            if mrr == 0:
                mrr = 1 / (1.0 + idx_)

        # if hit_10 > 1:
        #     print count
        #     raw_input('----------')

    # print (ranks)
    aps.append(np.mean(ranks))

    hit_1_list.append(hit_1)
    hit_3_list.append(hit_3)
    hit_10_list.append(hit_10)
    if correct == 0:
        mrr_list.append(0)
    else:
        mrr_list.append(mrr / correct)
    # score_label = zip(score_all, test_labels)
    # score_label_ranked = sorted(score_label, key=lambda x: x[0], reverse=True)
    # print ('score_label_ranked:',len(score_label_ranked))
    # print ('aps:', aps)

    # print hit_10_list

    mean_ap = np.mean(aps)
    mean_hit_1 = np.mean(hit_1_list)
    mean_hit_3 = np.mean(hit_3_list)
    mean_hit_10 = np.mean(hit_10_list)
    mean_mrr = np.mean(mrr_list)
    print('RL MAP: ', mean_ap)
    print('HITS@1: ', mean_hit_1)
    print('HITS@3: ', mean_hit_3)
    print('HITS@10: ', mean_hit_10)
    print('MRR: ', mean_mrr)

    with open(link_results_path, 'a') as f:
        f.write(relation + ':\n')
        f.write('RL MAP: ' + str(mean_ap) + '\n' + 'HITS@1: ' +
                str(mean_hit_1) + '\n' + 'HITS@3: ' + str(mean_hit_3) + '\n' +
                'HITS@10: ' + str(mean_hit_10) + '\n' + 'MRR: ' +
                str(mean_mrr) + '\n')
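The bookkeeping above repeats one pattern per query group: sort that query's (score, label) pairs by score, take the precision at each positive for AP, and the rank of the first positive for MRR and Hits@k. A minimal, self-contained sketch of that per-query computation (function and variable names are illustrative, not from the repository):

import numpy as np

def rank_metrics(scores, labels):
    """AP, MRR (of the first positive) and Hits@1/3/10 for a single query."""
    ranked = sorted(zip(scores, labels), key=lambda x: x[0], reverse=True)
    precisions, correct, first_rank = [], 0, None
    for i, (_, label) in enumerate(ranked):
        if label == 1:
            correct += 1
            precisions.append(correct / (1.0 + i))  # precision at this positive
            if first_rank is None:
                first_rank = i + 1                  # 1-based rank of the first positive
    ap = np.mean(precisions) if precisions else 0.0
    mrr = 1.0 / first_rank if first_rank else 0.0
    hits = {k: int(first_rank is not None and first_rank <= k) for k in (1, 3, 10)}
    return ap, mrr, hits

# rank_metrics([0.9, 0.7, 0.3], [0, 1, 1]) -> (0.583..., 0.5, {1: 0, 3: 1, 10: 1})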
Example #3
def fact_prediction_eval_logic():
	f1 = open(ent_id_path)
	f2 = open(rel_id_path)
	content1 = f1.readlines()
	content2 = f2.readlines()
	f1.close()
	f2.close()

	entity2id = {}
	relation2id = {}
	for line in content1:
		entity2id[line.split()[0]] = int(line.split()[1])

	for line in content2:
		relation2id[line.split()[0]] = int(line.split()[1])

	_, named_paths, occurrence_paths = get_features(feature_stats)

	length_weights = []
	for path in named_paths:
		weight = 1.0/len(path)
		length_weights.append(weight)
	length_weights = np.array(length_weights)

	"""
	path_weights = [elem / sum(occurrence_paths) for elem in occurrence_paths]
	path_weights = np.array(path_weights)
	"""
	kb = KB()
	kb_inv = KB()

	f = open(dataPath_ + '/graph.txt')
	kb_lines = f.readlines()
	f.close()

	for line in kb_lines:
		e1 = line.split()[0]
		rel = line.split()[1]
		e2 = line.split()[2]
		kb.addRelation(e1,rel,e2)
		kb_inv.addRelation(e2,rel,e1)

	f = open(test_data_path)
	test_data = f.readlines()
	f.close()
	test_pairs = []
	test_labels = []
	test_set = set()
	for line in test_data:
		e1 = line.split(',')[0].replace('thing$','')
		#e1 = '/' + e1[0] + '/' + e1[2:]
		e2 = line.split(',')[1].split(':')[0].replace('thing$','')
		#e2 = '/' + e2[0] + '/' + e2[2:]
		#if (e1 not in kb.entities) or (e2 not in kb.entities):
		#	continue
		test_pairs.append((e1,e2))
		label = 1 if line[-2] == '+' else 0
		test_labels.append(label)

	scores_rl = []

	print ('How many queries: ', len(test_pairs))
	for idx, sample in enumerate(test_pairs):
		print ('Query No.%d of %d' % (idx, len(test_pairs)))

		features = []
		for path in named_paths:
			features.append(int(bfs_two(sample[0], sample[1], path, kb, kb_inv)))
		features = features * length_weights
		score_rl = sum(features)
		scores_rl.append(score_rl)

	rank_stats_rl = list(zip(scores_rl, test_labels))
	rank_stats_rl.sort(key = lambda x:x[0], reverse=True)

	correct = 0
	ranks = []
	for idx, item in enumerate(rank_stats_rl):
		if item[1] == 1:
			correct += 1
			ranks.append(correct/(1.0+idx))
	ap3 = np.mean(ranks)
	# print(len(ranks))
	print ('RL: ', ap3)

	with open("logs/fact_prediction/" + relation + ".out", 'a') as fw:
		fw.write(filename + '\n')
		fw.write('RL fact prediction: ' + str(ap3) + '\n')
		fw.write("\n")
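Here the RL score is just a length-weighted count of the learned reasoning paths that actually connect the query pair (weight 1/len(path)). A small sketch of that scoring, where path_connects is a hypothetical stand-in for bfs_two:

import numpy as np

def path_score(e1, e2, named_paths, path_connects):
    """Score (e1, e2) by summing 1/len(path) over the paths that link them."""
    features = np.array([int(path_connects(e1, e2, p)) for p in named_paths])
    weights = np.array([1.0 / len(p) for p in named_paths])  # shorter paths weigh more
    return float(np.sum(features * weights))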
Example #4
def evaluate_logic():
    kb = KB()
    kb_inv = KB()

    f = open(dataPath_ + '/graph.txt')
    kb_lines = f.readlines()
    f.close()

    for line in kb_lines:
        e1 = line.split()[0]
        rel = line.split()[1]
        e2 = line.split()[2]
        kb.addRelation(e1, rel, e2)
        kb_inv.addRelation(e2, rel, e1)

    _, named_paths = get_features()

    model = train(kb, kb_inv, named_paths)

    f = open(dataPath_ + '/sort_test.pairs')
    test_data = f.readlines()
    f.close()
    test_pairs = []
    test_labels = []
    # queries = set()
    for line in test_data:
        e1 = line.split(',')[0].replace('thing$', '')
        # e1 = '/' + e1[0] + '/' + e1[2:]
        e2 = line.split(',')[1].split(':')[0].replace('thing$', '')
        # e2 = '/' + e2[0] + '/' + e2[2:]
        if (e1 not in kb.entities) or (e2 not in kb.entities):
            continue
        test_pairs.append((e1, e2))
        label = 1 if line[-2] == '+' else 0
        test_labels.append(label)

    aps = []
    query = test_pairs[0][0]
    y_true = []
    y_score = []

    score_all = []

    for idx, sample in enumerate(test_pairs):
        #print 'query node: ', sample[0], idx
        if sample[0] == query:
            features = []
            for path in named_paths:
                features.append(
                    int(bfs_two(sample[0], sample[1], path, kb, kb_inv)))

            #features = features*path_weights

            score = model.predict(np.reshape(features, [1, -1]))
            #score = np.sum(features)

            score_all.append(score[0])
            y_score.append(score)
            y_true.append(test_labels[idx])
        else:
            query = sample[0]
            count = list(zip(y_score, y_true))
            count.sort(key=lambda x: x[0], reverse=True)
            ranks = []
            correct = 0
            for idx_, item in enumerate(count):
                if item[1] == 1:
                    correct += 1
                    ranks.append(correct / (1.0 + idx_))
                    #break
            if len(ranks) == 0:
                aps.append(0)
            else:
                aps.append(np.mean(ranks))
            #print np.mean(ranks)
            # if len(aps) % 10 == 0:
            # 	print 'How many queries:', len(aps)
            # 	print np.mean(aps)
            y_true = []
            y_score = []
            features = []
            for path in named_paths:
                features.append(
                    int(bfs_two(sample[0], sample[1], path, kb, kb_inv)))

            #features = features*path_weights
            #score = np.inner(features, path_weights)
            #score = np.sum(features)
            score = model.predict(np.reshape(features, [1, -1]))

            score_all.append(score[0])
            y_score.append(score)
            y_true.append(test_labels[idx])
            # print y_score, y_true

    count = list(zip(y_score, y_true))
    count.sort(key=lambda x: x[0], reverse=True)
    ranks = []
    correct = 0
    for idx_, item in enumerate(count):
        if item[1] == 1:
            correct += 1
            ranks.append(correct / (1.0 + idx_))
    aps.append(np.mean(ranks))

    score_label = zip(score_all, test_labels)
    score_label_ranked = sorted(score_label, key=lambda x: x[0], reverse=True)

    mean_ap = np.mean(aps)
    print('RL MAP: ', mean_ap)
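This variant keeps only the MAP computation, again grouping consecutive test pairs by head entity with a running query variable plus a duplicated block after the loop. The same grouping can be expressed more compactly with itertools.groupby; the sketch below assumes, as the code above does, that sort_test.pairs lists all pairs of one head entity consecutively:

from itertools import groupby
import numpy as np

def mean_average_precision(test_pairs, scores, labels):
    """MAP over queries; consecutive pairs sharing a head entity form one query."""
    rows = list(zip(test_pairs, scores, labels))
    aps = []
    for _, group in groupby(rows, key=lambda r: r[0][0]):  # group by head entity
        ranked = sorted(group, key=lambda r: r[1], reverse=True)
        correct, precisions = 0, []
        for i, (_, _, label) in enumerate(ranked):
            if label == 1:
                correct += 1
                precisions.append(correct / (1.0 + i))
        aps.append(np.mean(precisions) if precisions else 0.0)
    return float(np.mean(aps))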
Example #5
def REINFORCE(training_pairs, policy_nn, optimizer, num_episodes, relation=None):
	f = open(graphpath)
	content = f.readlines()
	f.close()
	kb = KB()
	for line in content:
		ent1, rel, ent2 = line.rsplit()
		kb.addRelation(ent1, rel, ent2) # Each line is a triple, represented with strings instead of numbers
		
	dropout = nn.Dropout(dynamic_action_dropout_rate)

	train = training_pairs

	success = 0

	path_found = set()
	path_found_entity = []
	path_relation_found = []
	success_cnt_list = []

	env = Env(dataPath, train[0], model=args.model)
	# Initialize the environment

	for i_episode in range(num_episodes):
	# for i_episode in range(15):
		start = time.time()
		print ('Episode %d' % i_episode)
		sample = train[random.choice(range(len(training_pairs)))]
		print ('Training sample: ', sample[:-1])

		if relation is None:
			env = Env(dataPath, sample, args.model)
		else:
			env.path = []
			env.path_relations = []

		sample = sample.split()
		state_idx = [env.entity2id_[sample[0]], env.entity2id_[sample[1]], 0]

		episode = []

		state_batch_negative = []
		lstm_input_batch_negative = []
		hidden_batch_negative = []
		cell_batch_negative = []
		action_batch_negative = []
		now_embedding_batch_negative = []
		neighbour_embeddings_list_batch_negative = []

		state_batch_positive = []
		lstm_input_batch_positive = []
		hidden_batch_positive = []
		cell_batch_positive = []
		action_batch_positive = []
		now_embedding_batch_positive = []
		neighbour_embeddings_list_batch_positive = []

		hidden_this_time = torch.zeros(3, 1, hidden_dim)
		cell_this_time = torch.zeros(3, 1, hidden_dim)
		if USE_CUDA:
			hidden_this_time = hidden_this_time.cuda()
			cell_this_time = cell_this_time.cuda()

		forward_node_list = []

		for t in count():
		# for t in range(10):
			state_vec = floatTensor(env.idx_state(state_idx))
			state = torch.cat([state_vec, hidden_this_time[-1]], dim=1) # Only use the last layer's output
			lstm_input = state_vec.unsqueeze(1)

			now_embedding = floatTensor(env.entity2vec[[state_idx[0]]])

			connected_node_list = []
			if state_idx[0] in env.entity2link:
				for rel in env.entity2link[state_idx[0]]:
					connected_node_list.extend(env.entity2link[state_idx[0]][rel])
			connected_node_list = list(set(connected_node_list))
			if len(connected_node_list) == 0:
				neighbour_embeddings_list = [torch.zeros(1, embedding_dim).cuda() if USE_CUDA else torch.zeros(1, embedding_dim)]
			else:
				neighbour_embeddings_list = [floatTensor(env.entity2vec[connected_node_list])]

			action_probs, lstm_output, hidden_new, cell_new = policy_nn(state, lstm_input, hidden_this_time, cell_this_time, now_embedding, neighbour_embeddings_list)

			# Action Dropout
			dropout_action_probs = dropout(action_probs)
			# print(dropout_action_probs.shape)
			probability = np.squeeze(dropout_action_probs.cpu().detach().numpy())
			probability = probability / sum(probability)
			action_chosen = np.random.choice(np.arange(action_space), p = probability)

			reward, new_state, done = env.interact(state_idx, action_chosen)
			
			if reward == -1: # the action fails for this step
				state_batch_negative.append(state)
				lstm_input_batch_negative.append(lstm_input)
				hidden_batch_negative.append(hidden_this_time)
				cell_batch_negative.append(cell_this_time)
				action_batch_negative.append(action_chosen)
				now_embedding_batch_negative.append(now_embedding)
				neighbour_embeddings_list_batch_negative.append(neighbour_embeddings_list[0])

				# Force the agent to choose a valid action so it can move forward
				try:
					valid_action_list = list(env.entity2link[state_idx[0]].keys()) 
					probability = probability[valid_action_list]
					# print("Line 288: ", sum(probability))
					probability = probability / sum(probability)
					# print("Line 288: ", probability)
					valid_action_chosen = np.random.choice(valid_action_list, p = probability)
					valid_reward, valid_new_state, valid_done = env.interact(state_idx, valid_action_chosen)

					reward, new_state, done = valid_reward, valid_new_state, valid_done

					if new_state is None:
						forward_node_list.append(env.entity2id_[sample[1]]) # The right tail entity
					else:
						forward_node_list.append(new_state[0])

					state_batch_positive.append(state)
					lstm_input_batch_positive.append(lstm_input)
					hidden_batch_positive.append(hidden_this_time)
					cell_batch_positive.append(cell_this_time)
					action_batch_positive.append(valid_action_chosen)
					now_embedding_batch_positive.append(now_embedding)
					neighbour_embeddings_list_batch_positive.append(neighbour_embeddings_list[0])

					hidden_this_time = hidden_new
					cell_this_time = cell_new

				except Exception:
					print("Cannot find a valid action!")

			else: # the action found a valid edge, so the agent can move forward
				if new_state is None:
					forward_node_list.append(env.entity2id_[sample[1]]) # The right tail entity
				else:
					forward_node_list.append(new_state[0])

				state_batch_positive.append(state)
				lstm_input_batch_positive.append(lstm_input)
				hidden_batch_positive.append(hidden_this_time)
				cell_batch_positive.append(cell_this_time)
				action_batch_positive.append(action_chosen)
				now_embedding_batch_positive.append(now_embedding)
				neighbour_embeddings_list_batch_positive.append(neighbour_embeddings_list[0])

				hidden_this_time = hidden_new
				cell_this_time = cell_new

			new_state_vec = env.idx_state(new_state)
			episode.append(Transition(state = state_vec, action = action_chosen, next_state = new_state_vec, reward = reward))

			if done or t == max_steps:
				break

			state_idx = new_state
			
		# Discourage the agent when it chooses an invalid step
		if len(state_batch_negative) != 0 and done != 1:
			print ('Penalty to invalid steps:', len(state_batch_negative))
			
			policy_nn.zero_grad()
			action_mask = byteTensor(convert_to_one_hot(np.array(action_batch_negative), depth = action_space))
			# action_prob = torch.stack(action_prob_batch_negative).squeeze(1)
			# print(state_batch_negative[0].shape)
			state = torch.cat(state_batch_negative, dim=0)
			lstm_input = torch.cat(lstm_input_batch_negative, dim=1)
			hidden = torch.cat(hidden_batch_negative, dim=1)
			cell = torch.cat(cell_batch_negative, dim=1)
			now_embedding = torch.cat(now_embedding_batch_negative, dim=0)
			action_prob, lstm_output, hidden_new, cell_new = policy_nn(state, lstm_input, hidden, cell, now_embedding, neighbour_embeddings_list_batch_negative)
			# print(action_prob.shape)
			picked_action_prob = torch.masked_select(action_prob, action_mask)
			print(picked_action_prob)
			loss = -torch.sum(torch.log(picked_action_prob) * args.wrong_reward) # Reward for each invalid action is wrong_reward
			loss.backward(retain_graph=True)
			torch.nn.utils.clip_grad_norm_(policy_nn.parameters(), 0.2)
			optimizer.step()
			
		print ('----- FINAL PATH -----')
		print ('\t'.join(env.path))
		print ('PATH LENGTH', len(env.path))
		print ('----- FINAL PATH -----')
		
		# If the agent succeeds, do one optimization step
		if done == 1:
			print ('Success')
			
			path_found_entity.append(path_clean(' -> '.join(env.path)))

			success += 1

			# Compute the reward for a successful episode.
			path_length = len(env.path)
			length_reward = 1/path_length
			global_reward = 1

			if len(path_found) != 0:
				path_found_embedding = [env.path_embedding(path.split(' -> ')) for path in path_found]
				curr_path_embedding = env.path_embedding(env.path_relations)
				path_found_embedding = np.reshape(path_found_embedding, (-1,embedding_dim))
				cos_sim = cosine_similarity(path_found_embedding, curr_path_embedding)
				diverse_reward = -np.mean(cos_sim)
				print ('diverse_reward', diverse_reward)
				total_reward = args.global_reward_weight * global_reward + args.length_reward_weight * length_reward + args.diverse_reward_weight * diverse_reward 
			else:
				total_reward = args.global_reward_weight * global_reward + (args.length_reward_weight + args.diverse_reward_weight) * length_reward
			path_found.add(' -> '.join(env.path_relations))
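			# In effect, a successful episode's reward combines three terms:
			#   total_reward = w_global * 1                   (the target entity was reached)
			#                + w_length * (1 / path_length)   (shorter paths are preferred)
			#                + w_diverse * (-mean cosine sim) (unlike previously found paths)
			# When no earlier path exists, the diversity weight is folded into the length term.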

			# total_reward = 0.1*global_reward + 0.9*length_reward
			

			policy_nn.zero_grad()
			action_mask = byteTensor(convert_to_one_hot(np.array(action_batch_positive), depth = action_space))
			state = torch.cat(state_batch_positive, dim=0)
			lstm_input = torch.cat(lstm_input_batch_positive, dim=1)
			hidden = torch.cat(hidden_batch_positive, dim=1)
			cell = torch.cat(cell_batch_positive, dim=1)
			now_embedding = torch.cat(now_embedding_batch_positive, dim=0)
			action_prob, lstm_output, hidden_new, cell_new = policy_nn(state, lstm_input, hidden, cell, now_embedding, neighbour_embeddings_list_batch_positive)
			# print(action_prob.shape)
			picked_action_prob = torch.masked_select(action_prob, action_mask)
			loss = -torch.sum(torch.log(picked_action_prob) * total_reward) 
			# The reward for each step of a successful episode is total_reward
			loss.backward(retain_graph=True)
			torch.nn.utils.clip_grad_norm_(policy_nn.parameters(), 0.2)
			optimizer.step()
		else:

			if (len(state_batch_positive) != 0):
				# reward shaping
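				# All branches below follow the same pattern: score every node the agent
				# actually reached (forward_node_list) with the pretrained embedding model's
				# plausibility for (head entity, query relation, reached node), and use those
				# scores as shaped per-step rewards for the failed episode. For the
				# translation-based models this is a negative L1 distance, e.g.
				# -sum(|h + r - t|) for TransE, computed in the projected space for
				# TransH / TransR / TransD.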

				if args.reward_shaping_model == "TransH":
					# print("Enters TransH.")
					head = ent_embedding[[env.entity2id_[sample[0]]]]
					rel_emb = rel_embedding[[env.relation2id_[relation.replace('_', ':')]]]
					norm = norm_embedding[[env.relation2id_[relation.replace('_', ':')]]]
					tail = ent_embedding[forward_node_list]
					head_proj = head - np.sum(head * norm, axis=1, keepdims=True) * norm
					tail_proj = tail - np.sum(tail * norm, axis=1, keepdims=True) * norm
					scores = -np.sum(np.abs(head_proj + rel_emb - tail_proj), axis = 1)
					# print(scores)

				elif args.reward_shaping_model == "TransR":
					# print("Enters TransR.")
					head = ent_embedding[[env.entity2id_[sample[0]]]]
					rel_emb = rel_embedding[[env.relation2id_[relation.replace('_', ':')]]]
					norm = norm_embedding[[env.relation2id_[relation.replace('_', ':')]]].squeeze(0)
					tail = ent_embedding[forward_node_list]
					head_proj = np.matmul(norm, head.T).T
					tail_proj = np.matmul(norm, tail.T).T
					scores = -np.sum(np.abs(head_proj + rel_emb - tail_proj), axis = 1)
					# print(scores)

				elif args.reward_shaping_model == "TransD":
					# print("Enters TransD.")
					head = ent_embedding[[env.entity2id_[sample[0]]]]
					head_norm = ent_norm_embedding[[env.entity2id_[sample[0]]]]
					tail = ent_embedding[forward_node_list]
					tail_norm = ent_norm_embedding[forward_node_list]
					rel_emb = rel_embedding[[env.relation2id_[relation.replace('_', ':')]]]
					rel_norm = rel_norm_embedding[[env.relation2id_[relation.replace('_', ':')]]]
					head_proj = head + np.sum(head * head_norm, axis=1, keepdims=True) * rel_norm
					tail_proj = tail + np.sum(tail * tail_norm, axis=1, keepdims=True) * rel_norm
					scores = -np.sum(np.abs(head_proj + rel_emb - tail_proj), axis = 1)
					# print(scores)

				elif args.reward_shaping_model == "ProjE":
					# print("Enter ProjE.")
					h = ent_embedding[[env.entity2id_[sample[0]]]]
					r = rel_embedding[[env.relation2id_[relation.replace('_', ':')]]]
					ent_mat = np.transpose(ent_embedding)
					hr = h * simple_hr_combination_weights[:100] + r * simple_hr_combination_weights[100:]
					hrt_res = np.matmul(np.tanh(hr + combination_bias_hr), ent_mat)
					scores = hrt_res[0][forward_node_list]
					scores = torch.log(torch.sigmoid(torch.FloatTensor(scores))).numpy()
					# print(scores)

				elif args.reward_shaping_model == "ConvE":
					# print("Enters ConvE.")
					rel_id = TransE_to_ConvE_id_relation[env.relation2id_[relation.replace('_', ':')]]
					head_id = TransE_to_ConvE_id_entity[env.entity2id_[sample[0]]]
					tail_id = [TransE_to_ConvE_id_entity[elem] for elem in forward_node_list]

					bs = ConvE_model.batch_size
					x_middle, output = ConvE_model(longTensor([head_id] + [0] * (bs - 1)), longTensor([rel_id] * bs))

					scores = np.log(output[0][tail_id].detach().cpu().numpy() + 10 ** -30)
					# print(scores)

				else:
					head_embedding = ent_embedding[env.entity2id_[sample[0]]]
					query_embedding = rel_embedding[env.relation2id_[relation.replace('_', ':')]]
					tail_embedding = ent_embedding[forward_node_list]
					scores = -np.sum(np.abs(head_embedding + query_embedding - tail_embedding), axis = 1)

				policy_nn.zero_grad()
				action_mask = byteTensor(convert_to_one_hot(np.array(action_batch_positive), depth = action_space))
				state = torch.cat(state_batch_positive, dim=0)
				lstm_input = torch.cat(lstm_input_batch_positive, dim=1)
				hidden = torch.cat(hidden_batch_positive, dim=1)
				cell = torch.cat(cell_batch_positive, dim=1)
				now_embedding = torch.cat(now_embedding_batch_positive, dim=0)
				action_prob, lstm_output, hidden_new, cell_new = policy_nn(state, lstm_input, hidden, cell, now_embedding, neighbour_embeddings_list_batch_positive)
				# print(action_prob.shape)
				picked_action_prob = torch.masked_select(action_prob, action_mask)
				# print(picked_action_prob)
				loss = -torch.sum(torch.log(picked_action_prob) * floatTensor(scores) * args.useless_reward) 
				# The reward for each step of an unsuccessful episode is useless_reward
				loss.backward(retain_graph=True)
				torch.nn.utils.clip_grad_norm_(policy_nn.parameters(), 0.2)
				optimizer.step()
			
			print ('Failed, do one teacher-guided update') # Force the agent to learn from a successful sample
			teacher_success_flag = False
			teacher_success_failed_times = 0
			while (not teacher_success_flag) and teacher_success_failed_times < 3:
				try:
					good_episodes = teacher(sample[0], sample[1], 1, env, graphpath, knowledge_base = kb, output_mode = 1) # Episode's ID instead of state!
					if len(good_episodes) == 0:
						teacher_success_failed_times += 1
					else:
						for item in good_episodes:
							if len(item) == 0:
								teacher_success_failed_times += 1
								break

							teacher_state_batch = []
							teacher_action_batch = []
							teacher_now_embedding_batch = []
							teacher_neighbour_embeddings_list_batch = []

							total_reward = 0.0*1 + 1*1/len(item)

							for t, transition in enumerate(item):
								teacher_state_batch.append(floatTensor(env.idx_state(transition.state)))
								teacher_action_batch.append(transition.action)
								teacher_now_embedding_batch.append(floatTensor(env.entity2vec[[transition.state[0]]]))

								connected_node_list = []
								if transition.state[0] in env.entity2link:
									for rel in env.entity2link[transition.state[0]]:
										connected_node_list.extend(env.entity2link[transition.state[0]][rel])
								connected_node_list = list(set(connected_node_list)) # Remove duplicates
								if len(connected_node_list) == 0:
									if USE_CUDA:
										neighbour_embeddings_list = torch.zeros(1, embedding_dim).cuda()
									else:
										neighbour_embeddings_list = torch.zeros(1, embedding_dim)

								else:
									neighbour_embeddings_list = floatTensor(env.entity2vec[connected_node_list])

								teacher_neighbour_embeddings_list_batch.append(neighbour_embeddings_list)
							   
							if (len(teacher_state_batch) != 0):
								hidden_this_time = torch.zeros(3, 1, hidden_dim)
								cell_this_time = torch.zeros(3, 1, hidden_dim)
								if USE_CUDA:
									hidden_this_time = hidden_this_time.cuda()
									cell_this_time = cell_this_time.cuda()

								state_batch_teacher = []
								lstm_input_batch_teacher = []
								hidden_batch_teacher = []
								cell_batch_teacher = []

								for idx, state_vec in enumerate(teacher_state_batch):
									state_vec = floatTensor(state_vec)
									state = torch.cat([state_vec, hidden_this_time[-1]], dim=1) # Only use the last layer's output
									lstm_input = state_vec.unsqueeze(1)
									now_embedding = teacher_now_embedding_batch[idx]
									teacher_neighbour_embeddings_list = [teacher_neighbour_embeddings_list_batch[idx]]
									action_prob, lstm_output, hidden_new, cell_new = policy_nn(state, lstm_input, hidden_this_time, cell_this_time, now_embedding, teacher_neighbour_embeddings_list)
									# print(action_prob.shape)
									hidden_this_time = hidden_new
									cell_this_time = cell_new

									state_batch_teacher.append(state)
									lstm_input_batch_teacher.append(lstm_input)
									hidden_batch_teacher.append(hidden_this_time)
									cell_batch_teacher.append(cell_this_time)

								now_embedding = torch.cat(teacher_now_embedding_batch, dim=0)

								policy_nn.zero_grad()
								action_mask = byteTensor(convert_to_one_hot(np.array(teacher_action_batch), depth = action_space))
								state = torch.cat(state_batch_teacher, dim=0)
								lstm_input = torch.cat(lstm_input_batch_teacher, dim=1)
								hidden = torch.cat(hidden_batch_teacher, dim=1)
								cell = torch.cat(cell_batch_teacher, dim=1)
								action_prob, lstm_output, hidden_new, cell_new = policy_nn(state, lstm_input, hidden, cell, now_embedding, teacher_neighbour_embeddings_list_batch)
								# print(action_prob.shape)
								picked_action_prob = torch.masked_select(action_prob, action_mask)
								loss = -torch.sum(torch.log(picked_action_prob) * args.teacher_reward) # The reward for each step of a teacher episode is teacher_reward
								loss.backward(retain_graph=True)
								torch.nn.utils.clip_grad_norm_(policy_nn.parameters(), 0.2)
								optimizer.step()

								teacher_success_flag = True
							else:
								teacher_success_failed_times += 1
					
				except Exception as e:
					print ('Teacher guideline failed')
					teacher_success_failed_times += 10

		print ('Episode time: ', time.time() - start)
		print ('\n')
		print ("Retrain Success count: ", success)
		success_cnt_list.append(success)
	print ('Retrain Success percentage:', success/num_episodes)
	print (success_cnt_list)
	
	for path in path_found_entity: # Only successful paths
		rel_ent = path.split(' -> ')
		path_relation = []
		for idx, item in enumerate(rel_ent):
			if idx%2 == 0:
				path_relation.append(item)
		path_relation_found.append(' -> '.join(path_relation))
		
	relation_path_stats = collections.Counter(path_relation_found).items()
	relation_path_stats = sorted(relation_path_stats, key = lambda x:x[1], reverse=True) # Rank the paths according to their frequency.
	
	f = open(feature_stats, 'w')
	for item in relation_path_stats:
		f.write(item[0]+'\t'+str(item[1])+'\n')
	f.close()
	print ('Path stats saved')

	with open("logs/training/" + relation + ".out", 'a') as fw:
		fw.write(save_file_header + '_path_stats.txt' + '\n')
		fw.write('Retrain Success percentage: ' + str(success/num_episodes) + '\n')
		fw.write("Retrain success cnt list: ")
		fw.write(" ".join([str(elem) for elem in success_cnt_list]) + '\n')
		fw.write("\n")

	return
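Every optimization block in REINFORCE above applies the same masked policy-gradient update: build a one-hot mask over the chosen actions, masked_select their probabilities, and minimize the negative reward-weighted log-likelihood. A stand-alone sketch of that update with a generic policy network (the network, argument shapes and helper names are assumptions, not the repository's classes):

import torch
import torch.nn as nn

def reinforce_update(policy_net, optimizer, states, actions, reward, max_grad_norm=0.2):
    """One policy-gradient step: minimize -sum(log pi(a|s)) * reward."""
    probs = policy_net(states)                         # [batch, action_space]
    mask = torch.zeros_like(probs, dtype=torch.bool)
    mask[torch.arange(len(actions)), actions] = True   # one-hot over chosen actions
    picked = torch.masked_select(probs, mask)          # probabilities of chosen actions
    loss = -torch.sum(torch.log(picked) * reward)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy_net.parameters(), max_grad_norm)
    optimizer.step()
    return loss.item()

# Toy usage:
# policy = nn.Sequential(nn.Linear(8, 4), nn.Softmax(dim=1))
# reinforce_update(policy, torch.optim.Adam(policy.parameters()),
#                  torch.randn(5, 8), torch.tensor([0, 1, 3, 2, 1]), reward=0.8)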