def test_greedy(self, graph, path=None):
    # Greedy local search: at each step evaluate every legal action and take the
    # one with the largest immediate reward; stop when no action improves S.
    self.problem.g = to_cuda(graph)
    res = [self.problem.calc_S().item()]
    pb = self.problem
    if path is not None:
        path = os.path.abspath(os.path.join(os.getcwd())) + path
        vis_g(pb, name=path + str(0), topo='cut')
    R = []
    Reward = []
    for j in range(100):
        M = []
        actions = pb.get_legal_actions()
        # one-step look-ahead reward for every legal action
        for k in range(actions.shape[0]):
            _, r = pb.step(actions[k], state=dc(pb.g))
            M.append(r)
        if max(M) <= 0:
            break
        if path is not None:
            vis_g(pb, name=path + str(j + 1), topo='cut')
        posi = [x for x in M if x > 0]
        nega = [x for x in M if x <= 0]
        # print('posi reward ratio:', len(posi) / len(M))
        # print('posi reward avg:', sum(posi) / len(posi))
        # print('nega reward avg:', sum(nega) / len(nega))
        max_idx = torch.tensor(M).argmax().item()
        _, r = pb.step((actions[max_idx, 0].item(), actions[max_idx, 1].item()))
        R.append((actions[max_idx, 0].item(), actions[max_idx, 1].item(), r.item()))
        Reward.append(r.item())
        res.append(res[-1] - r.item())
    return QtableKey2state(
        state2QtableKey(
            pb.g.ndata['label'].argmax(dim=1).cpu().numpy())), R, res
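
# Usage sketch (hypothetical names; `tester` is an instance of this class and
# `g` a single un-batched problem graph):
#   final_label, moves, S_traj = tester.test_greedy(g)
#   # `moves` holds the chosen (i, j, reward) swaps and `S_traj` the objective
#   # value S after each greedy step.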
def prune_actions(self, bg, recent_states):
    # Mask out legal actions that would move an episode back into one of its
    # recently visited states (to discourage short loops).
    batch_size = bg.batch_size
    n = bg.ndata['label'].shape[0] // batch_size
    batch_legal_actions = self.problem.get_legal_actions(
        state=bg,
        action_type=self.action_type,
        action_dropout=self.action_dropout).cuda()
    num_actions = batch_legal_actions.shape[0] // batch_size
    forbid_action_mask = torch.zeros(batch_legal_actions.shape[0], 1).cuda()
    for k in range(self.new_epi_batch_size, bg.batch_size):
        # current state of the k-th episode, encoded as a Q-table key
        cur_state = state2QtableKey(
            bg.ndata['label'][k * n:(k + 1) * n, :].argmax(dim=1).cpu().numpy())
        forbid_states = set(recent_states[k - self.new_epi_batch_size])
        candidate_actions = batch_legal_actions[k * num_actions:(k + 1) * num_actions, :]
        for j in range(num_actions):
            cur_state_l = QtableKey2state(cur_state)
            # simulate the action by exchanging the two nodes' group labels (XOR swap)
            cur_state_l[candidate_actions[j][0]] ^= cur_state_l[candidate_actions[j][1]]
            cur_state_l[candidate_actions[j][1]] ^= cur_state_l[candidate_actions[j][0]]
            cur_state_l[candidate_actions[j][0]] ^= cur_state_l[candidate_actions[j][1]]
            if state2QtableKey(cur_state_l) in forbid_states:
                forbid_action_mask[j + k * num_actions] += 1
    # zero out the rows of forbidden actions
    batch_legal_actions = batch_legal_actions * (
        1 - forbid_action_mask.int()).t().flatten().unsqueeze(1)
    return batch_legal_actions
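
# The three XOR assignments above swap the group labels of the two nodes named
# by an action; a plainer (hypothetical, unused) equivalent would be:
#   def apply_swap(labels, i, j):
#       labels = list(labels)
#       labels[i], labels[j] = labels[j], labels[i]
#       return labels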
def cmpt_optimal(self, graph, path=None):
    # Brute-force baseline: enumerate all 280 distinct partitions of the 9 nodes
    # into three groups of 3 (9! / (3!^3 * 3!) = 280, treating group labels as
    # interchangeable) and keep the labelling with minimal S.
    self.problem.g = to_cuda(graph)
    res = [self.problem.calc_S().item()]
    pb = self.problem
    S = []
    for j in range(280):
        pb.reset_label(QtableKey2state(self.all_states[j]))
        S.append(pb.calc_S())
    s1 = torch.tensor(S).argmin()
    res.append(S[s1].item())
    if path is not None:
        path = os.path.abspath(os.path.join(os.getcwd())) + path
        pb.reset_label(QtableKey2state(self.all_states[s1]))
        vis_g(pb, name=path, topo='cut')
    return QtableKey2state(self.all_states[s1]), res
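
# Usage sketch (hypothetical names; same `tester` and graph `g` as above):
#   opt_label, opt_S = tester.cmpt_optimal(g)
#   # opt_S[-1] is the exhaustive optimum of S; comparing it against
#   # test_greedy's S_traj[-1] gives the greedy optimality gap.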
def test_init_state(self, alg, aux, g_i, t, trial_num=10):
    # Sample `trial_num` random initial labellings for graph `g_i`, score each
    # with the auxiliary state-evaluation model, and return the labelling with
    # the highest score.
    init_state = dc(self.episodes[g_i].init_state)
    init_states_label = [
        QtableKey2state(key)
        for key in np.random.choice(self.all_states, trial_num, replace=False)
    ]
    evals = []
    for k in range(trial_num):
        print('trial:', k)
        self.problem.g = dc(init_state)
        self.problem.reset_label(label=init_states_label[k])
        print('init state:', init_states_label[k])
        S = self.problem.calc_S()
        bg = dgl.batch([self.problem.g])
        state_eval = self.get_state_eval_from_aux(bg, alg, aux)
        evals.append(state_eval.detach().item())
        print('state_eval:', state_eval)
        # print('init S:', S)
        # roll out the learned policy for t steps from this initial labelling
        for i in range(t):
            actions = self.problem.get_legal_actions(
                state=self.problem.g, action_type=self.action_type)
            S_a_encoding, h1, h2, Q_sa = alg.forward(
                to_cuda(self.problem.g),
                actions.cuda(),
                action_type=self.action_type,
                gnn_step=3)
            _, r = self.problem.step(
                state=self.problem.g,
                action=actions[torch.argmax(Q_sa)],
                action_type=self.action_type)
            S -= r.item()
            # print(Q_sa.detach().cpu().numpy())
            # print('action index:', torch.argmax(Q_sa).detach().cpu().item())
            # print('action:', actions[torch.argmax(Q_sa)].detach().cpu().numpy())
            # print('reward:', r.item())
            # print('kcut S:', S)
        print(S)
    return init_states_label[torch.tensor(evals).argmax()]
# Script-level evaluation fragment (uses a `problem` instance and a pre-allocated
# result array `res` defined earlier in the script). The enumeration of all
# distinct states below is assumed to be assigned as `all_states = list(set([...]))`;
# only the comprehension body appears in the original source.
all_states = list(set([
    state2QtableKey(x)
    for x in list(itertools.permutations([0, 0, 0, 1, 1, 1, 2, 2, 2], 9))
]))
select_problem = []
for i in range(20000):
    # brute force
    problem.reset()
    res[0, i] = problem.calc_S().item()
    pb1 = dc(problem)
    pb1.g = to_cuda(pb1.g)
    S = []
    for j in range(280):
        pb1.reset_label(QtableKey2state(all_states[j]))
        S.append(pb1.calc_S())
    s1 = torch.tensor(S).argmin()
    pb1.reset_label(QtableKey2state(all_states[s1]))
    res[1, i] = S[s1]
    # print(s1)

    # greedy algorithm
    # problem.g = gg[i]
    pb2 = dc(problem)
    pb2.g = to_cuda(pb2.g)
    # pb2 = problem
    R = []
    Reward = []
    for j in range(100):
def run_test(self, problem=None, init_trial=10, trial_num=1, batch_size=100,
             gnn_step=3, episode_len=50, explore_prob=0.1, Temperature=1.0,
             aux_model=None, beta=0):
    # Roll out the learned Q-network on a batch of test problems for
    # `episode_len` steps: actions are sampled from a softmax over Q-values at
    # the given Temperature, with epsilon-greedy exploration (explore_prob),
    # optional revisit forbidding, and an optional auxiliary state-evaluation
    # penalty (beta > 0).
    self.episode_len = episode_len
    self.num_actions = self.problem.get_legal_actions(
        action_type=self.action_type).shape[0]
    if isinstance(self.problem.g, BatchedDGLGraph):
        self.num_actions //= self.problem.g.batch_size
    # if self.action_type == 'swap':
    #     self.num_actions = self.problem.N * (self.problem.N - self.problem.m) // 2 + 1
    # else:
    #     self.num_actions = self.problem.N * (self.problem.k - 1) + 1
    self.trial_num = trial_num
    self.batch_size = batch_size

    if problem is None:
        # generate a fresh batch of problems, repeated `trial_num` times
        batch_size *= trial_num
        bg = self.problem.gen_batch_graph(batch_size=self.batch_size)
        self.bg = to_cuda(self.problem.gen_batch_graph(
            x=bg.ndata['x'].repeat(trial_num, 1), batch_size=batch_size))
        gl = dgl.unbatch(self.bg)
        test_problem = self.problem
    else:
        if aux_model is not None and init_trial > 1:
            # pick a promising initial labelling for each graph via the aux model
            gg = []
            gl = dgl.unbatch(problem)
            for i in range(batch_size):
                init_label_i = self.test_init_state(alg=self.alg,
                                                    aux=aux_model,
                                                    g_i=i,
                                                    t=episode_len,
                                                    trial_num=init_trial)
                self.problem.g = gl[i]
                self.problem.reset_label(label=init_label_i)
                gg.append(dc(self.problem.g))
            self.bg = dgl.batch(gg)
            self.problem.g = self.bg
            test_problem = self.problem
        else:
            # for validation problem set
            self.bg = problem
            gl = dgl.unbatch(self.bg)
            self.problem.g = self.bg
            test_problem = self.problem

    # offsets mapping per-episode action indices into the batched action tensor
    self.action_mask = torch.tensor(
        range(0, self.num_actions * batch_size, self.num_actions)).cuda()
    self.S = [test_problem.calc_S(g=gl[i]).cpu() for i in range(batch_size)]
    ep = [EpisodeHistory(gl[i], max_episode_len=episode_len, action_type='swap')
          for i in range(batch_size)]
    self.end_of_episode = dict.fromkeys(range(batch_size), episode_len - 1)
    loop_start_position = [0] * self.batch_size
    loop_remain_index = set(range(self.batch_size))
    self.valid_states = [[ep[i].init_state] for i in range(self.batch_size)]

    for i in tqdm(range(episode_len)):
        batch_legal_actions = test_problem.get_legal_actions(
            state=self.bg, action_type=self.action_type).cuda()

        # if self.forbid_revisit and i > 0:
        #     # previous_actions = torch.tensor([ep[k].action_seq[i-1][0] * (self.n ** 2) + ep[k].action_seq[i-1][1] for k in range(batch_size)])
        #     bla = (self.n ** 2) * batch_legal_actions.view(batch_size, -1, 2)[:, :, 0] + batch_legal_actions.view(batch_size, -1, 2)[:, :, 1]
        #     forbid_action_mask = (bla.t() == previous_actions.cuda())
        #     batch_legal_actions = batch_legal_actions * (1 - forbid_action_mask.int()).t().flatten().unsqueeze(1)
        forbid_action_mask = torch.zeros(batch_legal_actions.shape[0], 1).cuda()
        if self.forbid_revisit:
            # mask actions that would lead back to a recently visited state
            for k in range(batch_size):
                cur_state = ep[k].enc_state_seq[-1]  # current state of the k-th episode
                forbid_states = set(ep[k].enc_state_seq[-1 - self.forbid_revisit:-1])
                candidate_actions = batch_legal_actions[
                    k * self.num_actions:(k + 1) * self.num_actions, :]
                for j in range(self.num_actions):
                    cur_state_l = QtableKey2state(cur_state)
                    # simulate the swap by exchanging the two nodes' labels (XOR swap)
                    cur_state_l[candidate_actions[j][0]] ^= cur_state_l[candidate_actions[j][1]]
                    cur_state_l[candidate_actions[j][1]] ^= cur_state_l[candidate_actions[j][0]]
                    cur_state_l[candidate_actions[j][0]] ^= cur_state_l[candidate_actions[j][1]]
                    if state2QtableKey(cur_state_l) in forbid_states:
                        forbid_action_mask[j + k * self.num_actions] += 1
        batch_legal_actions = batch_legal_actions * (
            1 - forbid_action_mask.int()).t().flatten().unsqueeze(1)

        if self.q_net == 'mlp':
            # bg1 = to_cuda(self.bg)
            # bg1.ndata['label'] = bg1.ndata['label'][:, [0, 2, 1]]
            # bg2 = to_cuda(self.bg)
            # bg2.ndata['label'] = bg2.ndata['label'][:, [1, 0, 2]]
            # bg3 = to_cuda(self.bg)
            # bg3.ndata['label'] = bg3.ndata['label'][:, [1, 2, 0]]
            # bg4 = to_cuda(self.bg)
            # bg4.ndata['label'] = bg4.ndata['label'][:, [2, 0, 1]]
            # bg5 = to_cuda(self.bg)
            # bg5.ndata['label'] = bg5.ndata['label'][:, [2, 1, 0]]
            # S_a_encoding, h1, h2, Q_sa = self.alg.forward(dgl.batch([self.bg, bg1, bg2, bg3, bg4, bg5]),
            #                                               torch.cat([batch_legal_actions] * 6),
            #                                               action_type=self.action_type, gnn_step=gnn_step)
            S_a_encoding, h1, h2, Q_sa = self.alg.forward(
                self.bg, batch_legal_actions,
                action_type=self.action_type, gnn_step=gnn_step)
            # Q_sa = Q_sa.view(6, -1, self.num_actions).sum(dim=0)
        else:
            S_a_encoding, h1, h2, Q_sa = self.alg.forward_MHA(
                self.bg, batch_legal_actions,
                action_type=self.action_type, gnn_step=gnn_step)

        if beta > 0:
            # penalize Q-values by an auxiliary evaluation of the successor states
            self.problem.sample_episode *= self.num_actions
            self.problem.gen_step_batch_mask()
            new_bg, _ = self.problem.step_batch(
                states=dgl.batch([self.bg] * self.num_actions),
                action=batch_legal_actions.view(batch_size, self.num_actions, 2)
                                          .transpose(0, 1).reshape(-1, 2),
                action_type=self.action_type,
                return_sub_reward=False)
            self.problem.sample_episode = self.problem.sample_episode // self.num_actions
            self.problem.gen_step_batch_mask()
            state_eval = self.get_state_eval_from_aux(new_bg, self.alg, aux_model)
            state_eval = state_eval.view(-1, batch_size).t().reshape(batch_size * self.num_actions)
            self.state_eval.append(state_eval.view(batch_size, self.num_actions))
            Q_sa -= state_eval * beta

        # episodes whose best Q-value is negative are marked as terminated
        terminate_episode = (Q_sa.view(-1, self.num_actions).max(dim=1).values
                             < 0.).nonzero().flatten().cpu().numpy()
        for idx in terminate_episode:
            if self.end_of_episode[idx] == episode_len - 1:
                self.end_of_episode[idx] = i

        # sample actions from a Boltzmann policy, then apply epsilon-greedy exploration
        best_actions = torch.multinomial(
            F.softmax(Q_sa.view(-1, self.num_actions) / Temperature, dim=1), 1).view(-1)
        chose_actions = torch.tensor([
            x if torch.rand(1) > explore_prob
            else torch.randint(high=self.num_actions, size=(1,)).squeeze()
            for x in best_actions
        ]).cuda()
        chose_actions += self.action_mask
        actions = batch_legal_actions[chose_actions]

        new_states, rewards = self.problem.step_batch(
            states=self.bg, action=actions, action_type=self.action_type)
        R = [reward.item() for reward in rewards]

        # enc_states = [state2QtableKey(self.bg.ndata['label'][k * self.n:(k + 1) * self.n].argmax(dim=1).cpu().numpy()) for k in range(batch_size)]
        [ep[k].write(action=actions[k, :],
                     action_idx=chose_actions[k] - self.action_mask[k],  # per-episode index of the action taken
                     reward=R[k],
                     q_val=Q_sa.view(-1, self.num_actions)[k, :],
                     actions=batch_legal_actions.view(-1, self.num_actions, 2)[k, :, :],
                     state_enc=None,  # enc_states[k]
                     # sub_reward=sub_rewards[:, k].cpu().numpy()
                     sub_reward=None,
                     loop_start_position=loop_start_position[k])
         for k in range(batch_size)]

        # # write valid sample episode for training aux model
        # tmp = set()
        # for k in loop_remain_index:
        #     if enc_states[k] in ep[k].enc_state_seq[-4:-1]:
        #         loop_start_position[k] = i
        #         tmp.add(k)
        #     else:
        #         label = ep[k].init_state.ndata['label'][ep[k].label_perm[-1]]
        #         l = label.argmax(dim=1)
        #         self.problem.g = dc(ep[k].init_state)
        #         self.problem.reset_label(label=l)
        #         self.valid_states[k].append(dc(self.problem.g))
        # loop_remain_index = loop_remain_index.difference(tmp)

    self.episodes = ep
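
# Usage sketch (hypothetical names; `tester` is an instance of this class and
# `val_bg` a batched DGLGraph of validation problems):
#   tester.run_test(problem=val_bg, batch_size=val_bg.batch_size,
#                   episode_len=50, explore_prob=0.0, Temperature=1e-8)
#   # a very low Temperature makes the softmax sampling near-greedy; per-episode
#   # histories are then in tester.episodes and the initial S values in tester.S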