def move(self, state):  # state = [cur_e, cur_r, next_e, target_e]
    cur_e, cur_r, next_e, target_e = state[0], state[1], state[2], state[3]
    # The first hop starts from the dummy START entity.
    if self.env.entid2name[cur_e] == 'START':
        step = 0
    else:
        step = 1
    action_space = self.env.get_action_space(next_e)
    logger.MARK('moving>>>>>action_space:%s' % str(action_space))
    if len(action_space) > 0:
        state_vec = self.env.get_state_vec(cur_e, cur_r, target_e)
        action_pro = self.move_predict(state_vec, mode=self.policymethod, step=step)
        # Restrict the predicted distribution to the valid actions and pick the best one.
        action_space_pro = action_pro[action_space]
        action = torch.argmax(action_space_pro)
        action = action_space[action.item()]
        logger.MARK('moving>>>>>move action:%s' % str(self.env.relid2name[action]))
        if action == self.env.relOPid:
            logger.MARK('moving>>>>>moving to OP..')
            return next_e, self.env.relOPid, self.env.entOPid
        next_r = action
        next_next_e = self.env.choose_e_from_action(next_e, next_r)
        return next_e, next_r, next_next_e
    else:
        # No outgoing edges: fall back to the OP (stop) action.
        logger.MARK('moving>>>>>moving to OP..')
        return next_e, self.env.relOPid, self.env.entOPid

def trajectory(self, e1, target_e, max_len=4):
    Observ = self.env.get_initial_state(e1, target_e)  # [cur_e, cur_r, next_e, target_e]
    traj = [list(Observ)]
    success = False
    for i in range(max_len):
        cur_e, cur_r, next_e, target_e = Observ[0], Observ[1], Observ[2], Observ[3]
        logger.MARK('Observing:%s' % str(Observ))
        if next_e == target_e:
            # Reached the target entity.
            success = True
            traj = self.get_reward(traj, success=success)
            break
        elif next_e != target_e and i == max_len - 1:
            # Ran out of steps: close the trajectory with the OP (stop) action,
            # using ids so the step matches the format returned by move().
            traj.append([next_e, self.env.relOPid, self.env.entOPid, target_e])
            traj = self.get_reward(traj, success=success)
            break
        elif cur_r == self.env.relOPid:
            # The agent already chose to stop.
            traj = self.get_reward(traj, success=success)
            break
        else:
            next_e, next_r, next_next_e = self.move(Observ)
            Observ = [next_e, next_r, next_next_e, target_e]
            traj.append(Observ)
    assert len(traj) > 0
    return traj, success

def valid(pairs, enviro, ag, batchsize, maxlen=None):
    valid_pairs = pairs
    random.shuffle(valid_pairs)
    try_count, batch_loss, ave_reward, ave_success = 0, 0, 0, 0
    for query in valid_pairs:
        try:
            e1, e2 = query[0], query[1]
            e1, e2 = enviro.entity2id[e1], enviro.entity2id[e2]
            with torch.no_grad():
                traj, success = ag.trajectory(e1, e2, max_len=maxlen)
            try_count += 1
        except KeyError:
            # Skip queries whose entities are not in the vocabulary.
            continue
        logger.MARK(enviro.traj_for_showing(traj))
        traj_loss = 0
        with torch.no_grad():
            ag.policy.zero_history()
            for i in traj:
                ave_reward += i[4]
                loss = ag.update_memory_policy(i)
                traj_loss += loss.cpu()
        if success:
            ave_success += 1
        batch_loss += traj_loss / len(traj)
    logger.info('|%d have been validated|Batch_loss:%.4f|Ave_reward:%.3f|Ave_success:%%%.2f|' % (
        try_count, batch_loss * 100 / try_count, ave_reward / try_count, ave_success * 100 / try_count))

def generate_paths(Env_a, agent_a, pairs, save_path, maxlen):
    try_count, batch_loss, ave_reward, ave_success = 0, 0, 0, 0
    paths = []
    for query in pairs:
        try:
            e1, e2 = query[0], query[1]
            e1, e2 = Env_a.entity2id[e1], Env_a.entity2id[e2]
            with torch.no_grad():
                traj, success = agent_a.trajectory(e1, e2, max_len=maxlen)
            try_count += 1
        except KeyError:
            continue
        if success:
            logger.MARK('Find paths on the test:' + Env_a.traj_for_showing(traj))
            L = Env_a.traj2list(traj)
            paths.append(L)
    # Write one tab-separated path per line.
    with open(save_path, 'w') as fout:
        wrt_str = ['\t'.join(i) for i in paths]
        wrt_str = '\n'.join(wrt_str)
        fout.write(wrt_str)

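# The file-based train() variant further below calls a load_pair() helper that is
# not shown in this section. The following is only a sketch of what it is assumed
# to do: read one "<head>\t<tail>" pair per line from the exported pair files.
# If the repository already defines load_pair() elsewhere, that definition is the
# authoritative one.
def load_pair(pair_path):
    pairs = []
    with open(pair_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            e1, e2 = line.split('\t')[:2]
            pairs.append((e1, e2))
    return pairs
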
def train(task_relation="<diedIn>", rootpath=None, epoch=5): datapath = {'type2id': rootpath + 'type2id.json', 'relation2id': rootpath + 'relation2id.json', \ 'graph': rootpath + 'graph.pkl', 'ent2type': rootpath + 'ent2type.json' ,\ 'entity2id': rootpath + 'entity2id.json'} Env_a = env(datapath) Env_a.init_relation_query_state(task_relation) batchsize = 20 maxlen = 5 po = Policy_memory(Env_a, 300, 100, Env_a.rel_num) Env_a.filter_query(maxlen, 5000) pairs = Env_a.filter_query random.shuffle(pairs) training_pairs = pairs test_pairs = pairs[:int(len(pairs) * 0.5)] reward_record = [] success_record = [] path_length = 0 valid_paris = pairs[int(len(pairs) * 0.5):int(len(pairs) * 0.6)] print('Train pairs:', len(training_pairs)) print('valid pairs:', len(valid_paris)) print('Test pairs:', len(test_pairs)) agent_a = agent(po, Env_a, policymethod='GRU') if global_device == 'cuda:0': po = po.cuda() try_count, batch_loss, ave_reward, ave_success = 0, 0, 0, 0 opt = torch.optim.Adam(agent_a.parameters() + Env_a.parameters(), lr=0.001) for ep in range(epoch): opt.zero_grad() random.shuffle(training_pairs) for query in training_pairs: try: e1, e2 = query[0], query[1] e1, e2 = Env_a.entity2id[e1], Env_a.entity2id[e2] with torch.no_grad(): traj, success = agent_a.trajectory(e1, e2, max_len=maxlen) try_count += 1 except KeyError: continue logger.MARK(Env_a.traj_for_showing(traj)) traj_loss = 0 po.zero_history() traj_reward = 0 for i in traj: ave_reward += i[4] traj_reward += i[4] loss = agent_a.update_memory_policy(i) loss.backward() traj_loss += loss.cpu() if success: ave_success += 1 path_length += len(traj) - 1 success_record.append(1) else: success_record.append(0) reward_record.append(traj_reward) batch_loss += traj_loss / len(traj) if try_count % batchsize == 0 and try_count > 0: opt.step() opt.zero_grad() logger.info( '|%d epoch|%d eposide|Batch_loss:%.4f|Ave_reward:%.3f|Ave_success:%%%.2f|ave path lenghth:%.2f|' % (ep, try_count, batch_loss * 100 / batchsize, ave_reward / batchsize, ave_success * 100 / batchsize, path_length / ave_success)) batch_loss, ave_reward, ave_success, path_length = 0, 0, 0, 0 if try_count % (20 * batchsize) == 0 and try_count > 0: valid(valid_paris, Env_a, agent_a, batchsize, maxlen) generate_paths(Env_a, agent_a, test_pairs, rootpath + task_relation + '.paths', maxlen) success = ave_smooth(success_record, 20) reward = ave_smooth(reward_record, 20) with open(rootpath + task_relation + 'sucess_record_without.txt', 'w') as fin: wstr = '\n'.join([str(i) for i in success]) fin.write(wstr) with open(rootpath + task_relation + 'reward_record_without.txt', 'w') as fin: wstr = '\n'.join([str(i) for i in reward]) fin.write(wstr) with open(rootpath + task_relation + 'test_positive_pairs', 'w') as fin: wstr = [] for i in test_pairs: wstr.append(str(i[0] + '\t' + str(i[1]))) wstr = '\n'.join(wstr) fin.write(wstr)
def train(task_relation="<diedIn>",rootpath=None,epoch=5): datapath = {'type2id': rootpath + 'type2id.json', 'relation2id': rootpath + 'relation2id.json', \ 'graph': rootpath + 'graph.pkl', 'ent2type': rootpath + 'ent2type.json' ,\ 'entity2id': rootpath + 'entity2id.json'} Env_a = env(datapath) Env_a.init_relation_query_state(task_relation) batchsize=20 maxlen=5 po = Policy_memory(Env_a,300, 100, Env_a.rel_num) # Env_a.filter_query(maxlen,5000) # pairs = Env_a.filter_query # random.shuffle(pairs) # training_pairs=pairs # test_pairs=pairs[:int(len(pairs)*0.5)] # valid_paris=pairs[int(len(pairs)*0.5):int(len(pairs)*0.6)] train_path=rootpath+'/'+task_relation+'train_pairs' valid_path = rootpath + '/' + task_relation + 'valid_pairs' training_pairs=load_pair(train_path) valid_paris=load_pair(valid_path) print('Train pairs:',len(training_pairs)) print('valid pairs:',len(valid_paris)) #print('Test pairs:',len(test_pairs)) agent_a = agent(po, Env_a,policymethod='GRU') if global_device=='cuda:0': po=po.cuda() try_count, batch_loss, ave_reward, ave_success = 0, 0, 0, 0 opt=torch.optim.Adam(agent_a.parameters()+Env_a.parameters(),lr=0.001) for ep in range(epoch): opt.zero_grad() random.shuffle(training_pairs) for query in training_pairs: try: e1, e2 = query[0], query[1] e1, e2 = Env_a.entity2id[e1], Env_a.entity2id[e2] with torch.no_grad(): traj, success = agent_a.trajectory(e1, e2,max_len=maxlen) try_count += 1 except KeyError: continue logger.MARK(Env_a.traj_for_showing(traj)) traj_loss=0 po.zero_history() traj_reward=0 for i in traj: ave_reward+=i[4] traj_reward+=i[4] loss=agent_a.update_memory_policy(i) loss.backward() traj_loss+=loss.cpu() if success: ave_success+=1 batch_loss+=traj_loss/len(traj) if try_count%batchsize==0 and try_count>0: opt.step() opt.zero_grad() logger.info('|%d epoch|%d eposide|Batch_loss:%.4f|Ave_reward:%.3f|Ave_success:%%%.2f|'%(ep,try_count,batch_loss*100/batchsize,ave_reward/batchsize,ave_success*100/batchsize)) batch_loss,ave_reward,ave_success=0,0,0 if try_count%(20*batchsize)==0 and try_count>0: valid(valid_paris,Env_a,agent_a,batchsize,maxlen) generate_paths(Env_a,agent_a,test_pairs,rootpath+task_relation+'.paths',maxlen)