示例#1
0
 def test_greedy(self, graph, path=None):
     """Run a greedy local search on ``graph`` and report the trajectory.

     At each step every legal action is scored on a deep copy of the current
     graph. The action with the largest reward is then applied to
     ``self.problem``. The search stops after 100 steps, when no action has a
     positive reward, or when there are no legal actions at all.

     Args:
         graph: Graph to solve; passed through ``to_cuda`` and stored on
             ``self.problem.g``.
         path: Optional suffix appended to the current working directory and
             used as a filename prefix for ``vis_g`` snapshots. Snapshot ``i``
             shows the state after ``i`` greedy steps.

     Returns:
         Tuple ``(labels, R, res)``:
         - ``labels``: the final node labels (argmax of
           ``g.ndata['label']``, round-tripped through the Q-table key
           encoding).
         - ``R``: the list of applied ``(i, j, reward)`` triples.
         - ``res``: the objective trace, whose first entry is the initial
           ``calc_S()``; each later entry subtracts that step's reward.
     """
     self.problem.g = to_cuda(graph)
     pb = self.problem
     res = [pb.calc_S().item()]
     if path is not None:
         path = os.path.abspath(os.path.join(os.getcwd())) + path
         vis_g(pb, name=path + str(0), topo='cut')
     R = []
     for j in range(100):
         actions = pb.get_legal_actions()
         # Score every candidate on a copy so that pb.g itself is untouched.
         M = [pb.step(actions[k], state=dc(pb.g))[1]
              for k in range(actions.shape[0])]
         # Stop at a local optimum, or when there is nothing to try
         # (max() would raise on an empty list).
         if not M or max(M) <= 0:
             break
         max_idx = torch.tensor(M).argmax().item()
         a, b = actions[max_idx, 0].item(), actions[max_idx, 1].item()
         _, r = pb.step((a, b))
         R.append((a, b, r.item()))
         res.append(res[-1] - r.item())
         # Snapshot AFTER applying the step, so frame j + 1 really shows the
         # state after j + 1 steps (frame 0 was written before the loop).
         if path is not None:
             vis_g(pb, name=path + str(j + 1), topo='cut')
     return QtableKey2state(
         state2QtableKey(
             pb.g.ndata['label'].argmax(dim=1).cpu().numpy())), R, res
示例#2
0
    def prune_actions(self, bg, recent_states):
        """Zero out legal actions that would lead back to a recently seen state.

        The check covers every graph ``k`` in the batch from
        ``self.new_epi_batch_size`` onward. Each candidate swap action is
        applied to a freshly decoded copy of that graph's current labels. If
        the resulting Q-table key is in
        ``recent_states[k - self.new_epi_batch_size]``, the action's row is
        multiplied by 0.

        Args:
            bg: Batched graph. The per-graph node count ``n`` is derived by
                integer division, so this assumes every graph has the same
                number of nodes — TODO confirm.
            recent_states: Sequence indexed by
                ``k - self.new_epi_batch_size``; each entry is an iterable of
                Q-table keys to avoid.

        Returns:
            The legal-action tensor (moved with ``.cuda()``) with forbidden
            rows zeroed. Columns 0 and 1 of each row are the two node indices
            whose labels are swapped.
        """
        batch_size = bg.batch_size
        n = bg.ndata['label'].shape[0] // batch_size

        batch_legal_actions = self.problem.get_legal_actions(
            state=bg,
            action_type=self.action_type,
            action_dropout=self.action_dropout).cuda()
        num_actions = batch_legal_actions.shape[0] // batch_size
        forbid_action_mask = torch.zeros(batch_legal_actions.shape[0],
                                         1).cuda()

        for k in range(self.new_epi_batch_size, batch_size):
            cur_state = state2QtableKey(
                bg.ndata['label'][k * n:(k + 1) * n, :].argmax(
                    dim=1).cpu().numpy())  # current state of graph k
            forbid_states = set(recent_states[k - self.new_epi_batch_size])
            # Copy the candidate index pairs to the host once, as plain ints,
            # instead of indexing GPU tensors element by element.
            candidate_actions = batch_legal_actions[
                k * num_actions:(k + 1) * num_actions, :2].tolist()
            for j, (a, b) in enumerate(candidate_actions):
                cur_state_l = QtableKey2state(cur_state)
                # Tuple swap, not the XOR trick: XOR-swapping an index with
                # itself (a == b, e.g. an action already zeroed to (0, 0))
                # would set that label to 0 instead of leaving it unchanged.
                cur_state_l[a], cur_state_l[b] = cur_state_l[b], cur_state_l[a]
                if state2QtableKey(cur_state_l) in forbid_states:
                    forbid_action_mask[j + k * num_actions] += 1
        # The (N, 1) mask broadcasts across each action row.
        batch_legal_actions = batch_legal_actions * (
            1 - forbid_action_mask.int())

        return batch_legal_actions
示例#3
0
    def cmpt_optimal(self, graph, path=None):
        """Brute-force the best labelling among the first 280 stored states.

        Each of ``self.all_states[0..279]`` is decoded with
        ``QtableKey2state`` and written onto ``self.problem`` via
        ``reset_label``. The state with the smallest ``calc_S()`` wins.
        (280 presumably equals the number of distinct balanced partitions
        for this problem size — TODO confirm.)

        Args:
            graph: Graph to evaluate; passed through ``to_cuda`` and stored on
                ``self.problem.g``.
            path: Optional suffix appended to the current working directory.
                When given, the winning labelling is re-applied and rendered
                with ``vis_g`` under that name.

        Returns:
            ``(labels, res)``: ``labels`` is the decoded best state, and
            ``res`` is ``[initial S, best S]``.

        Note:
            When ``path`` is None, ``self.problem`` is left holding whatever
            labels ``reset_label`` last wrote (the 280th candidate), not the
            best one.
        """
        self.problem.g = to_cuda(graph)
        res = [self.problem.calc_S().item()]
        pb = self.problem

        def score(key):
            # Relabel the problem graph with `key` and return its objective.
            pb.reset_label(QtableKey2state(key))
            return pb.calc_S()

        S = [score(self.all_states[idx]) for idx in range(280)]
        best = torch.tensor(S).argmin()
        res.append(S[best].item())

        if path is not None:
            out_name = os.path.abspath(os.path.join(os.getcwd())) + path
            pb.reset_label(QtableKey2state(self.all_states[best]))
            vis_g(pb, name=out_name, topo='cut')
        return QtableKey2state(self.all_states[best]), res
示例#4
0
    def test_init_state(self, alg, aux, g_i, t, trial_num=10):
        """Score random initial labellings of episode ``g_i`` with the aux model.

        ``trial_num`` distinct keys are sampled, without replacement, from
        ``self.all_states``. For each key:
        - the episode's initial graph is relabelled;
        - it is scored with ``get_state_eval_from_aux``;
        - it is rolled out for ``t`` steps, always taking the arg-max Q action
          from ``alg.forward``.
        Progress and each rollout's final objective are printed.

        Args:
            alg: Q-network whose ``forward`` returns four values, the last
                being the per-action Q-values.
            aux: Auxiliary model passed through to ``get_state_eval_from_aux``.
            g_i: Index into ``self.episodes``; its ``init_state`` is the base
                graph.
            t: Number of greedy rollout steps per trial.
            trial_num: Number of initial labellings to try.

        Returns:
            The sampled initial label vector with the highest aux evaluation.
            The greedy rollout result is only printed; it does not affect
            the choice.
        """
        base_graph = dc(self.episodes[g_i].init_state)
        candidates = [
            QtableKey2state(key) for key in np.random.choice(
                self.all_states, trial_num, replace=False)
        ]

        evals = []
        for trial, label in enumerate(candidates):
            print('trial:', trial)
            self.problem.g = dc(base_graph)
            self.problem.reset_label(label=label)
            print('init state:', label)
            S = self.problem.calc_S()

            state_eval = self.get_state_eval_from_aux(
                dgl.batch([self.problem.g]), alg, aux)
            evals.append(state_eval.detach().item())
            print('state_eval:', state_eval)

            # Greedy rollout: always take the action with the highest Q.
            for _step in range(t):
                legal = self.problem.get_legal_actions(
                    state=self.problem.g, action_type=self.action_type)
                _, _, _, q_values = alg.forward(to_cuda(self.problem.g),
                                                legal.cuda(),
                                                action_type=self.action_type,
                                                gnn_step=3)
                _, reward = self.problem.step(
                    state=self.problem.g,
                    action=legal[torch.argmax(q_values)],
                    action_type=self.action_type)
                S -= reward.item()

            print(S)
        return candidates[torch.tensor(evals).argmax()]
示例#5
0
        state2QtableKey(x)
        for x in list(itertools.permutations([0, 0, 0, 1, 1, 1, 2, 2, 2], 9))
    ]))

select_problem = []

for i in range(20000):

    # # brute force
    problem.reset()
    res[0, i] = problem.calc_S().item()
    pb1 = dc(problem)
    pb1.g = to_cuda(pb1.g)
    S = []
    for j in range(280):
        pb1.reset_label(QtableKey2state(all_states[j]))
        S.append(pb1.calc_S())

    s1 = torch.tensor(S).argmin()
    pb1.reset_label(QtableKey2state(all_states[s1]))
    res[1, i] = S[s1]
    # print(s1)

    # greedy algorithm
    # problem.g = gg[i]
    pb2 = dc(problem)
    pb2.g = to_cuda(pb2.g)
    # pb2 = problem
    R = []
    Reward = []
    for j in range(100):
示例#6
0
    def run_test(self,
                 problem=None,
                 init_trial=10,
                 trial_num=1,
                 batch_size=100,
                 gnn_step=3,
                 episode_len=50,
                 explore_prob=0.1,
                 Temperature=1.0,
                 aux_model=None,
                 beta=0):
        """Roll out the Q-network on a batch of test graphs for ``episode_len`` steps.

        Args:
            problem: ``None`` to generate ``batch_size`` new graphs, each
                replicated ``trial_num`` times. Otherwise, a batched graph to
                test on. In that case, when ``aux_model`` is given and
                ``init_trial > 1``, each graph's initial labels are first
                picked with ``test_init_state``.
            init_trial: Number of candidate initial labellings per graph
                tried by ``test_init_state``.
            trial_num: Replication factor applied to generated graphs.
            batch_size: Number of graphs (before replication).
            gnn_step: Forwarded to the Q-network.
            episode_len: Number of rollout steps.
            explore_prob: Per-episode, per-step probability of replacing the
                sampled action with a uniformly random one.
            Temperature: Softmax temperature used to sample actions from Q.
            aux_model: Optional auxiliary state evaluator.
            beta: When > 0, ``beta`` times the aux evaluation of every
                successor state is subtracted from that action's Q-value
                (evaluations are also appended to ``self.state_eval``).

        Side effects:
            Sets these attributes, among others:
            - ``self.bg``: the test batch.
            - ``self.S``: the initial objective per graph.
            - ``self.end_of_episode``: the first step whose best Q-value is
              negative, else ``episode_len - 1``.
            - ``self.valid_states``.
            - ``self.episodes``: one ``EpisodeHistory`` per graph.
        """
        self.episode_len = episode_len
        self.num_actions = self.problem.get_legal_actions(
            action_type=self.action_type).shape[0]
        if isinstance(self.problem.g, BatchedDGLGraph):
            self.num_actions //= self.problem.g.batch_size
        self.trial_num = trial_num
        self.batch_size = batch_size
        if problem is None:
            # Generate batch_size graphs and repeat their features trial_num
            # times, so the effective batch is batch_size * trial_num.
            batch_size *= trial_num
            bg = self.problem.gen_batch_graph(batch_size=self.batch_size)
            self.bg = to_cuda(
                self.problem.gen_batch_graph(x=bg.ndata['x'].repeat(
                    trial_num, 1),
                                             batch_size=batch_size))
            gl = dgl.unbatch(self.bg)
            test_problem = self.problem
        else:
            if aux_model is not None and init_trial > 1:
                # NOTE(review): test_init_state reads self.episodes, i.e. the
                # episodes from a previous run_test call.
                gg = []
                gl = dgl.unbatch(problem)
                for i in range(batch_size):
                    init_label_i = self.test_init_state(alg=self.alg,
                                                        aux=aux_model,
                                                        g_i=i,
                                                        t=episode_len,
                                                        trial_num=init_trial)
                    self.problem.g = gl[i]
                    self.problem.reset_label(label=init_label_i)
                    gg.append(dc(self.problem.g))
                self.bg = dgl.batch(gg)
                self.problem.g = self.bg
                test_problem = self.problem
            else:
                # for validation problem set
                self.bg = problem
                gl = dgl.unbatch(self.bg)
                self.problem.g = self.bg
                test_problem = self.problem

        # Offset of episode k's first action in the flattened action list.
        self.action_mask = torch.tensor(
            range(0, self.num_actions * batch_size, self.num_actions)).cuda()

        self.S = [
            test_problem.calc_S(g=gl[i]).cpu() for i in range(batch_size)
        ]

        ep = [
            EpisodeHistory(gl[i],
                           max_episode_len=episode_len,
                           action_type='swap') for i in range(batch_size)
        ]
        self.end_of_episode = {}.fromkeys(range(batch_size), episode_len - 1)
        # Sized by the effective batch (which includes trial_num replication),
        # since it is indexed for every episode below.
        # Neither this list nor self.valid_states is extended in this method.
        loop_start_position = [0] * batch_size
        self.valid_states = [[ep[i].init_state]
                             for i in range(self.batch_size)]
        for i in tqdm(range(episode_len)):
            batch_legal_actions = test_problem.get_legal_actions(
                state=self.bg, action_type=self.action_type).cuda()

            forbid_action_mask = torch.zeros(batch_legal_actions.shape[0],
                                             1).cuda()
            if self.forbid_revisit:
                for k in range(batch_size):
                    # current state for the k-th episode
                    cur_state = ep[k].enc_state_seq[-1]
                    # the last `forbid_revisit` states before the current one
                    forbid_states = set(
                        ep[k].enc_state_seq[-1 - self.forbid_revisit:-1])
                    candidate_actions = batch_legal_actions[
                        k * self.num_actions:(k + 1) * self.num_actions,
                        :2].tolist()
                    for j, (a, b) in enumerate(candidate_actions):
                        cur_state_l = QtableKey2state(cur_state)
                        # Tuple swap: the XOR-swap trick would zero the label
                        # when a == b (e.g. a masked (0, 0) action).
                        cur_state_l[a], cur_state_l[b] = (cur_state_l[b],
                                                          cur_state_l[a])
                        if state2QtableKey(cur_state_l) in forbid_states:
                            forbid_action_mask[j + k * self.num_actions] += 1

            # Forbidden actions become all-zero rows; the (N, 1) mask
            # broadcasts across each row.
            batch_legal_actions = batch_legal_actions * (
                1 - forbid_action_mask.int())

            if self.q_net == 'mlp':
                _, _, _, Q_sa = self.alg.forward(
                    self.bg,
                    batch_legal_actions,
                    action_type=self.action_type,
                    gnn_step=gnn_step)
            else:
                _, _, _, Q_sa = self.alg.forward_MHA(
                    self.bg,
                    batch_legal_actions,
                    action_type=self.action_type,
                    gnn_step=gnn_step)

            if beta > 0:
                # Evaluate every successor state: temporarily scale
                # sample_episode so step_batch handles num_actions copies of
                # the batch, then restore it.
                self.problem.sample_episode *= self.num_actions
                self.problem.gen_step_batch_mask()

                new_bg, _ = self.problem.step_batch(
                    states=dgl.batch([self.bg] * self.num_actions),
                    action=batch_legal_actions.view(batch_size,
                                                    self.num_actions,
                                                    2).transpose(0, 1).reshape(
                                                        -1, 2),
                    action_type=self.action_type,
                    return_sub_reward=False)
                self.problem.sample_episode = self.problem.sample_episode // self.num_actions
                self.problem.gen_step_batch_mask()

                state_eval = self.get_state_eval_from_aux(
                    new_bg, self.alg, aux_model)

                state_eval = state_eval.view(-1, batch_size).t().reshape(
                    batch_size * self.num_actions)
                self.state_eval.append(
                    state_eval.view(batch_size, self.num_actions))

                Q_sa -= state_eval * beta

            q_by_episode = Q_sa.view(-1, self.num_actions)

            # Episodes whose best Q-value is negative "end" at this step
            # (only the first such step is recorded).
            terminate_episode = (q_by_episode.max(dim=1).values <
                                 0.).nonzero().flatten().cpu().numpy()
            for idx in terminate_episode:
                if self.end_of_episode[idx] == episode_len - 1:
                    self.end_of_episode[idx] = i

            # Sample one action per episode from softmax(Q / T) over its row.
            best_actions = torch.multinomial(
                F.softmax(q_by_episode / Temperature, dim=1), 1).view(-1)

            chose_actions = torch.tensor([
                x if torch.rand(1) > explore_prob else torch.randint(
                    high=self.num_actions, size=(1, )).squeeze()
                for x in best_actions
            ]).cuda()
            # Convert per-episode indices into flat batch indices.
            chose_actions += self.action_mask

            actions = batch_legal_actions[chose_actions]

            _, rewards = self.problem.step_batch(
                states=self.bg, action=actions, action_type=self.action_type)
            R = [reward.item() for reward in rewards]

            for k in range(batch_size):
                ep[k].write(
                    action=actions[k, :],
                    # Index of the action actually taken within episode k's
                    # own action list. best_actions is already local, so
                    # subtracting the batch offset from it gave negative
                    # indices for k >= 1.
                    action_idx=chose_actions[k] - self.action_mask[k],
                    reward=R[k],
                    q_val=q_by_episode[k, :],
                    actions=batch_legal_actions.view(-1, self.num_actions,
                                                     2)[k, :, :],
                    state_enc=None,
                    sub_reward=None,
                    loop_start_position=loop_start_position[k])

        self.episodes = ep