Example 1
    def move(self, state):
        # state = [cur_e, cur_r, next_e, target_e]
        cur_e, cur_r, next_e, target_e = state[0], state[1], state[2], state[3]
        # step 0 only for the artificial START entity, 1 for every later hop
        step = 0 if self.env.entid2name[cur_e] == 'START' else 1
        action_space = self.env.get_action_space(next_e)
        logger.MARK('moving>>>>>action_space:%s' % str(action_space))
        if len(action_space) > 0:
            # score all relations, then keep only the valid outgoing ones
            state_vec = self.env.get_state_vec(cur_e, cur_r, target_e)
            action_pro = self.move_predict(state_vec, mode=self.policymethod, step=step)
            action_space_pro = action_pro[action_space]
            action = torch.argmax(action_space_pro)
            action = action_space[action.cpu()]
            logger.MARK('moving>>>>>move action:%s' % str(self.env.relid2name[action]))
            if action == self.env.relOPid:
                # the policy picked the OP (stop) relation: end at the OP entity
                logger.MARK('moving>>>>>moving to OP..')
                return next_e, self.env.relOPid, self.env.entOPid
            next_r = action
            next_next_e = self.env.choose_e_from_action(next_e, next_r)
            return next_e, next_r, next_next_e
        else:
            # dead end: no outgoing relation, stop at the OP entity
            logger.MARK('moving>>>>>moving to OP..')
            return next_e, self.env.relOPid, self.env.entOPid
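A minimal usage sketch for move, following the loop in Example 2's trajectory (the agent_a / Env_a / e1 / target_e names come from the other examples and are assumptions here):

# Hypothetical one-hop usage; trajectory() in Example 2 wraps exactly this in a loop.
Observ = Env_a.get_initial_state(e1, target_e)       # [cur_e, cur_r, next_e, target_e]
next_e, next_r, next_next_e = agent_a.move(Observ)   # the policy picks one relation to follow
Observ = [next_e, next_r, next_next_e, target_e]     # state for the next hop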
Example 2
 def trajectory(self, e1, target_e, max_len=4):
     Observ = self.env.get_initial_state(
         e1, target_e)  #[cur_e,cur_r,next_e,target_e]
     traj = [list(Observ)]
     success = False
     for i in range(max_len):
         cur_e, cur_r, next_e, target_e = Observ[0], Observ[1], Observ[2], Observ[3]
         logger.MARK('Observing:%s' % str(Observ))
         if next_e == target_e:
             success = True
             traj = self.get_reward(traj, success=success)
             break
         elif next_e != target_e and i == max_len - 1:
             traj.append([
                 next_e, self.env.relid2name['OP'],
                 self.env.entid2name['OP'], target_e
             ])
             #traj.append([next_e, self.env.relOPid, self.env.entOPid, target_e])
             traj = self.get_reward(traj, success=success)
             break
         elif cur_r == self.env.relOPid:
             traj = self.get_reward(traj, success=success)
             break
         else:
             next_e, next_r, next_next_e = self.move(Observ)
             Observ = [next_e, next_r, next_next_e, target_e]
             traj.append(Observ)
     assert len(traj) > 0
     return traj, success
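trajectory delegates reward assignment to get_reward, which is not shown in these examples; the training and validation loops later read the reward as i[4], so it must append a scalar to each step. A minimal sketch under that assumption (the actual reward values and any shaping are guesses, not the authors' scheme):

def get_reward(traj, success=False, step_penalty=0.0):
    # Hypothetical reward assignment: all the other examples require is that
    # every step ends up with a reward at index 4.
    reward = 1.0 if success else 0.0
    rewarded = []
    for step in traj:
        # step is [cur_e, cur_r, next_e, target_e]; append the reward as element 4
        rewarded.append(list(step) + [reward - step_penalty])
    return rewarded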
Example 3
def valid(pairs, enviro, ag, batchsize, maxlen=None):

    valid_pairs = pairs
    random.shuffle(valid_pairs)
    try_count, batch_loss, ave_reward, ave_success = 0, 0, 0, 0

    for query in valid_pairs:
        try:
            e1, e2 = query[0], query[1]
            e1, e2 = enviro.entity2id[e1], enviro.entity2id[e2]
            with torch.no_grad():
                traj, success = ag.trajectory(e1, e2, max_len=maxlen)
            try_count += 1
        except KeyError:
            continue

        logger.MARK(enviro.traj_for_showing(traj))

        traj_loss = 0
        with torch.no_grad():
            ag.policy.zero_history()
            for i in traj:
                ave_reward += i[4]
                loss = ag.update_memory_policy(i)
                traj_loss += loss.cpu()
        if success:
            ave_success += 1
        batch_loss += traj_loss / len(traj)

    # Guard against division by zero when every query was skipped (KeyError).
    if try_count == 0:
        return
    logger.info('|%d have been validated|Batch_loss:%.4f|Ave_reward:%.3f|Ave_success:%%%.2f|' % (
        try_count, batch_loss * 100 / try_count, ave_reward / try_count, ave_success * 100 / try_count))
Example 4
def generate_paths(Env_a, agent_a, pairs, save_path, maxlen):
    try_count, batch_loss, ave_reward, ave_success = 0, 0, 0, 0
    paths = []
    for query in pairs:
        try:
            e1, e2 = query[0], query[1]
            e1, e2 = Env_a.entity2id[e1], Env_a.entity2id[e2]
            with torch.no_grad():
                traj, success = agent_a.trajectory(e1, e2, max_len=maxlen)
            try_count += 1
        except KeyError:
            continue

        if success:
            logger.MARK('Find paths on the test:' + Env_a.traj_for_showing(traj))
            L = Env_a.traj2list(traj)
            paths.append(L)
    with open(save_path, 'w') as fout:
        wrt_str = ['\t'.join(i) for i in paths]
        wrt_str = '\n'.join(wrt_str)
        fout.write(wrt_str)
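generate_paths writes one successful path per line, with the elements of Env_a.traj2list(traj) joined by tabs. A minimal sketch of reading that file back (the element order inside each path depends on traj2list, which is not shown):

def load_paths(path_file):
    # Inverse of the writer above: one tab-separated path per line.
    paths = []
    with open(path_file) as f:
        for line in f:
            line = line.strip()
            if line:
                paths.append(line.split('\t'))
    return paths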
Example 5
def train(task_relation="<diedIn>", rootpath=None, epoch=5):
    datapath = {'type2id': rootpath + 'type2id.json', 'relation2id': rootpath + 'relation2id.json', \
         'graph': rootpath + 'graph.pkl', 'ent2type': rootpath + 'ent2type.json' ,\
         'entity2id': rootpath + 'entity2id.json'}
    Env_a = env(datapath)
    Env_a.init_relation_query_state(task_relation)
    batchsize = 20
    maxlen = 5
    po = Policy_memory(Env_a, 300, 100, Env_a.rel_num)
    Env_a.filter_query(maxlen, 5000)
    pairs = Env_a.filter_query
    random.shuffle(pairs)

    training_pairs = pairs
    test_pairs = pairs[:int(len(pairs) * 0.5)]
    reward_record = []
    success_record = []
    path_length = 0
    valid_pairs = pairs[int(len(pairs) * 0.5):int(len(pairs) * 0.6)]
    print('Train pairs:', len(training_pairs))
    print('Valid pairs:', len(valid_pairs))
    print('Test pairs:', len(test_pairs))
    agent_a = agent(po, Env_a, policymethod='GRU')
    if global_device == 'cuda:0':
        po = po.cuda()

    try_count, batch_loss, ave_reward, ave_success = 0, 0, 0, 0
    # parameters() typically returns an iterator, so materialize both before concatenating
    opt = torch.optim.Adam(list(agent_a.parameters()) + list(Env_a.parameters()), lr=0.001)
    for ep in range(epoch):
        opt.zero_grad()
        random.shuffle(training_pairs)
        for query in training_pairs:
            try:
                e1, e2 = query[0], query[1]
                e1, e2 = Env_a.entity2id[e1], Env_a.entity2id[e2]
                with torch.no_grad():
                    traj, success = agent_a.trajectory(e1, e2, max_len=maxlen)
                try_count += 1
            except KeyError:
                continue

            logger.MARK(Env_a.traj_for_showing(traj))

            traj_loss = 0
            po.zero_history()
            traj_reward = 0

            for i in traj:

                ave_reward += i[4]
                traj_reward += i[4]
                loss = agent_a.update_memory_policy(i)
                loss.backward()
                traj_loss += loss.cpu()
            if success:
                ave_success += 1
                path_length += len(traj) - 1
                success_record.append(1)
            else:
                success_record.append(0)
            reward_record.append(traj_reward)
            batch_loss += traj_loss / len(traj)
            if try_count % batchsize == 0 and try_count > 0:
                opt.step()
                opt.zero_grad()
                logger.info(
                    '|%d epoch|%d episode|Batch_loss:%.4f|Ave_reward:%.3f|Ave_success:%%%.2f|ave path length:%.2f|'
                    % (ep, try_count, batch_loss * 100 / batchsize,
                       ave_reward / batchsize, ave_success * 100 / batchsize,
                       path_length / max(ave_success, 1)))  # max() avoids division by zero when no path succeeded
                batch_loss, ave_reward, ave_success, path_length = 0, 0, 0, 0

            if try_count % (20 * batchsize) == 0 and try_count > 0:
                valid(valid_pairs, Env_a, agent_a, batchsize, maxlen)

        generate_paths(Env_a, agent_a, test_pairs,
                       rootpath + task_relation + '.paths', maxlen)

    success = ave_smooth(success_record, 20)
    reward = ave_smooth(reward_record, 20)

    with open(rootpath + task_relation + 'sucess_record_without.txt',
              'w') as fout:
        wstr = '\n'.join([str(i) for i in success])
        fout.write(wstr)
    with open(rootpath + task_relation + 'reward_record_without.txt',
              'w') as fout:
        wstr = '\n'.join([str(i) for i in reward])
        fout.write(wstr)

    with open(rootpath + task_relation + 'test_positive_pairs', 'w') as fout:
        wstr = []
        for i in test_pairs:
            wstr.append(str(i[0]) + '\t' + str(i[1]))
        wstr = '\n'.join(wstr)
        fout.write(wstr)
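The curves written above are first smoothed with ave_smooth(record, 20), which is not defined in these examples. A plausible sketch as a trailing-window mean (the window handling at the start of the series is an assumption):

def ave_smooth(record, window=20):
    # Hypothetical smoothing helper: mean over the last `window` entries seen so far.
    smoothed = []
    for i in range(len(record)):
        start = max(0, i - window + 1)
        chunk = record[start:i + 1]
        smoothed.append(sum(chunk) / len(chunk))
    return smoothed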
Example 6
def train(task_relation="<diedIn>",rootpath=None,epoch=5):
    datapath = {'type2id': rootpath + 'type2id.json', 'relation2id': rootpath + 'relation2id.json', \
         'graph': rootpath + 'graph.pkl', 'ent2type': rootpath + 'ent2type.json' ,\
         'entity2id': rootpath + 'entity2id.json'}
    Env_a = env(datapath)
    Env_a.init_relation_query_state(task_relation)
    batchsize=20
    maxlen=5
    po = Policy_memory(Env_a,300, 100, Env_a.rel_num)
    # Env_a.filter_query(maxlen,5000)
    # pairs = Env_a.filter_query
    # random.shuffle(pairs)
    # training_pairs=pairs
    # test_pairs=pairs[:int(len(pairs)*0.5)]
    # valid_paris=pairs[int(len(pairs)*0.5):int(len(pairs)*0.6)]
    train_path = rootpath + '/' + task_relation + 'train_pairs'
    valid_path = rootpath + '/' + task_relation + 'valid_pairs'
    # Assumed: test pairs follow the same naming scheme; test_pairs is needed by
    # generate_paths at the end of each epoch.
    test_path = rootpath + '/' + task_relation + 'test_pairs'
    training_pairs = load_pair(train_path)
    valid_pairs = load_pair(valid_path)
    test_pairs = load_pair(test_path)

    print('Train pairs:', len(training_pairs))
    print('Valid pairs:', len(valid_pairs))
    print('Test pairs:', len(test_pairs))
    agent_a = agent(po, Env_a,policymethod='GRU')
    if global_device=='cuda:0':
        po=po.cuda()

    try_count, batch_loss, ave_reward, ave_success = 0, 0, 0, 0
    # parameters() typically returns an iterator, so materialize both before concatenating
    opt = torch.optim.Adam(list(agent_a.parameters()) + list(Env_a.parameters()), lr=0.001)
    for ep in range(epoch):
        opt.zero_grad()
        random.shuffle(training_pairs)
        for query in training_pairs:
            try:
                e1, e2 = query[0], query[1]
                e1, e2 = Env_a.entity2id[e1], Env_a.entity2id[e2]
                with torch.no_grad():
                    traj, success = agent_a.trajectory(e1, e2,max_len=maxlen)
                try_count += 1
            except KeyError:
                continue
            logger.MARK(Env_a.traj_for_showing(traj))
            traj_loss=0
            po.zero_history()
            traj_reward=0

            for i in traj:

                ave_reward+=i[4]
                traj_reward+=i[4]
                loss=agent_a.update_memory_policy(i)
                loss.backward()
                traj_loss+=loss.cpu()
            if success:
                ave_success+=1
            batch_loss+=traj_loss/len(traj)
            if try_count%batchsize==0 and try_count>0:
                opt.step()
                opt.zero_grad()
                logger.info('|%d epoch|%d episode|Batch_loss:%.4f|Ave_reward:%.3f|Ave_success:%%%.2f|' % (
                    ep, try_count, batch_loss * 100 / batchsize, ave_reward / batchsize, ave_success * 100 / batchsize))
                batch_loss,ave_reward,ave_success=0,0,0

            if try_count%(20*batchsize)==0 and try_count>0:
                valid(valid_pairs, Env_a, agent_a, batchsize, maxlen)

        generate_paths(Env_a, agent_a, test_pairs, rootpath + task_relation + '.paths', maxlen)
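Example 6 loads its query pairs with load_pair, which is not shown. Example 5 writes pairs as one tab-separated (e1, e2) per line, so a matching loader might look like this (the exact file layout is an assumption):

def load_pair(pair_file):
    # Hypothetical loader for the '<e1>\t<e2>' per-line format used elsewhere.
    pairs = []
    with open(pair_file) as f:
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) >= 2:
                pairs.append((parts[0], parts[1]))
    return pairs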