Example #1
    def prune_actions(self, bg, recent_states):
        batch_size = bg.batch_size
        n = bg.ndata['label'].shape[0] // batch_size

        batch_legal_actions = self.problem.get_legal_actions(
            state=bg,
            action_type=self.action_type,
            action_dropout=self.action_dropout).cuda()
        num_actions = batch_legal_actions.shape[0] // batch_size
        forbid_action_mask = torch.zeros(batch_legal_actions.shape[0],
                                         1).cuda()

        # the first new_epi_batch_size graphs are fresh episodes with no
        # history, so only the continuing episodes are pruned
        for k in range(self.new_epi_batch_size, bg.batch_size):
            cur_state = state2QtableKey(
                bg.ndata['label'][k * n:(k + 1) * n, :].argmax(
                    dim=1).cpu().numpy())  # current state of the k-th episode
            forbid_states = set(recent_states[k - self.new_epi_batch_size])
            candidate_actions = batch_legal_actions[k * num_actions:(k + 1) *
                                                    num_actions, :]
            for j in range(num_actions):
                # simulate the candidate swap in place via the triple-XOR
                # trick, then forbid it if it returns to a recent state
                cur_state_l = QtableKey2state(cur_state)
                u, v = candidate_actions[j][0], candidate_actions[j][1]
                cur_state_l[u] ^= cur_state_l[v]
                cur_state_l[v] ^= cur_state_l[u]
                cur_state_l[u] ^= cur_state_l[v]
                if state2QtableKey(cur_state_l) in forbid_states:
                    forbid_action_mask[j + k * num_actions] += 1
        batch_legal_actions = batch_legal_actions * (
            1 - forbid_action_mask.int()).t().flatten().unsqueeze(1)

        return batch_legal_actions
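The pruning loop above simulates each candidate swap with the classic triple-XOR trick, so no temporary variable or extra copy of the label vector is needed. A minimal standalone illustration (plain Python, hypothetical `labels` list, not the project's API); note the trick only works when the two indices differ, since `x ^= x` zeroes the entry:

# Illustration of the in-place triple-XOR swap used in prune_actions.
labels = [0, 0, 0, 1, 1, 1, 2, 2, 2]   # hypothetical partition vector
u, v = 1, 4                            # swap the labels of nodes 1 and 4

labels[u] ^= labels[v]
labels[v] ^= labels[u]
labels[u] ^= labels[v]

assert labels == [0, 1, 0, 1, 0, 1, 2, 2, 2]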
Example #2
    def compare_opt(self, validation_problem):

        bingo = 0
        for i in range(self.batch_size):
            self.problem.g = self.episodes[i].init_state
            end_perm = self.episodes[i].label_perm[self.end_of_episode[i]]
            end_label = self.episodes[i].init_state.ndata['label'][end_perm]
            dqn_result = state2QtableKey(end_label.argmax(dim=1).cpu().numpy())
            opt_result = state2QtableKey(validation_problem[i][1])
            grd_result = state2QtableKey(validation_problem[i][2])
            if dqn_result == opt_result:
                bingo += 1
        print('Optimal solution hits: {}/{}'.format(bingo, self.batch_size))
        return bingo
Example #3
    def test_greedy(self, graph, path=None):
        self.problem.g = to_cuda(graph)
        res = [self.problem.calc_S().item()]
        pb = self.problem
        if path is not None:
            path = os.getcwd() + path
            vis_g(pb, name=path + str(0), topo='cut')
        R = []
        Reward = []
        for j in range(100):
            # evaluate the reward of every legal action on a copy of the graph
            M = []
            actions = pb.get_legal_actions()
            for k in range(actions.shape[0]):
                _, r = pb.step(actions[k], state=dc(pb.g))
                M.append(r)
            # stop when no action improves the objective
            if max(M) <= 0:
                break
            if path is not None:
                vis_g(pb, name=path + str(j + 1), topo='cut')
            posi = [x for x in M if x > 0]
            nega = [x for x in M if x <= 0]
            # print('posi reward ratio:', len(posi) / len(M))
            # print('posi reward avg:', sum(posi) / len(posi))
            # print('nega reward avg:', sum(nega) / len(nega))
            # greedily take the best action and record it
            max_idx = torch.tensor(M).argmax().item()
            _, r = pb.step((actions[max_idx, 0].item(), actions[max_idx, 1].item()))
            R.append((actions[max_idx, 0].item(), actions[max_idx, 1].item(), r.item()))
            Reward.append(r.item())
            res.append(res[-1] - r.item())
        return QtableKey2state(
            state2QtableKey(
                pb.g.ndata['label'].argmax(dim=1).cpu().numpy())), R, res
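A minimal usage sketch for the greedy baseline above; `test` and `problem` are placeholders for the test harness and K-cut problem objects constructed elsewhere in the project:

# Illustrative call only: run greedy 2-swap local search on one graph.
final_state, moves, s_trace = test.test_greedy(problem.g)

print('final partition:', final_state)      # node labels after greedy search
print('greedy moves taken:', len(moves))    # list of (node_i, node_j, reward)
print('objective value per step:', s_trace)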
Example #4
    def __init__(self, alg, problem=None, q_net='mlp', forbid_revisit=False):
        if isinstance(alg, DQN):
            self.alg = alg.model
        else:
            self.alg = alg

        self.problem = problem
        if problem is not None:
            if isinstance(problem.g, BatchedDGLGraph):
                self.n = (problem.g.ndata['label'].shape[0]
                          // problem.g.batch_size)
            else:
                self.n = problem.g.ndata['label'].shape[0]
        # if isinstance(problem, BatchedDGLGraph):
        #     self.problem = problem
        # elif isinstance(problem, list):
        #     self.problem = dgl.batch(problem)
        # else:
        #     self.problem = False

        self.episodes = []
        self.S = []
        self.max_gain = []
        self.max_gain_budget = []
        self.max_gain_ratio = []
        # self.action_indices = DataFrame(range(27))
        self.q_net = q_net
        self.forbid_revisit = forbid_revisit

        # all distinct partition keys of 9 nodes into three groups of three
        self.all_states = list(
            set([
                state2QtableKey(x) for x in list(
                    itertools.permutations([0, 0, 0, 1, 1, 1, 2, 2, 2], 9))
            ]))

        self.state_eval = []
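For reference, the size of `all_states` can be checked by hand: `itertools.permutations` of the multiset yields 9! = 362,880 tuples, of which 9!/(3!)^3 = 1,680 are distinct labelings, and the `range(280)` loop in a later example suggests `state2QtableKey` also identifies the 3! relabelings of the clusters, leaving 280 keys. A quick arithmetic check with only the standard library (illustration, not project code):

from math import factorial

labelings = factorial(9) // (factorial(3) ** 3)  # 1680 labeled 3/3/3 splits
canonical = labelings // factorial(3)            # 280 if cluster ids are interchangeable
print(labelings, canonical)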
Example #5
    def __init__(self, g, max_episode_len, action_type='swap'):
        self.action_type = action_type
        self.init_state = dc(g)

        self.n = g.number_of_nodes()
        self.max_episode_len = max_episode_len
        self.episode_len = 0
        self.action_seq = []
        self.action_indices = []
        self.reward_seq = []
        self.q_pred = []
        self.action_candidates = []
        self.enc_state_seq = []
        self.sub_reward_seq = []
        # self.calib_reward_seq = []
        if self.action_type == 'swap':
            # identity permutation: labels at step t are recovered as
            # init_state.ndata['label'][label_perm[t]]
            self.label_perm = torch.arange(self.n).unsqueeze(0)
            self.enc_state_seq.append(
                state2QtableKey(self.init_state.ndata['label'].argmax(
                    dim=1).cpu().numpy()))
        elif self.action_type == 'flip':
            self.label_perm = self.init_state.ndata['label'].nonzero(
            )[:, 1].unsqueeze(0)
        # self.visited_state = set([''.join([str(i.item()) for i in torch.tensor(range(self.n)).unsqueeze(0)[0]])])
        # self.node_visit_cnt = [0] * self.n
        self.best_gain_sofar = 0
        self.current_gain = 0
        self.loop_start_position = 0
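For 'swap' episodes the node labels themselves are never rewritten; the history only stacks one index permutation per step, and the labeling at step t is recovered by indexing the initial one-hot labels with that permutation (this is exactly how `compare_opt` reads the final state in an earlier example). A toy sketch of that indexing with made-up tensors:

import torch

# Made-up one-hot labels for 6 nodes in 3 clusters.
init_label = torch.eye(3)[torch.tensor([0, 0, 1, 1, 2, 2])]

# Permutation a 'swap node 1 with node 3' step would append to label_perm.
perm_t = torch.tensor([0, 3, 2, 1, 4, 5])

label_t = init_label[perm_t]          # labels at step t
print(label_t.argmax(dim=1))          # tensor([0, 1, 1, 0, 2, 2])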
Example #6
import os
import itertools
import numpy as np
from copy import deepcopy as dc
from tqdm import tqdm
import pickle
from toy_models.Qiter import vis_g, state2QtableKey, QtableKey2state
# KCut_DGL and to_cuda come from the project's own modules (imports omitted here)

os.environ['CUDA_VISIBLE_DEVICES'] = '1'

problem = KCut_DGL(k=3,
                   m=3,
                   adjacent_reserve=5,
                   hidden_dim=16,
                   mode='complete')
res = np.zeros((3, 20000))
greedy_move_history = []
# all 280 canonical partitions of 9 nodes into three groups of three
all_states = list(
    set([
        state2QtableKey(x)
        for x in list(itertools.permutations([0, 0, 0, 1, 1, 1, 2, 2, 2], 9))
    ]))

select_problem = []

for i in range(20000):

    # brute force: evaluate the objective over all 280 canonical partitions
    problem.reset()
    res[0, i] = problem.calc_S().item()
    pb1 = dc(problem)
    pb1.g = to_cuda(pb1.g)
    S = []
    for j in range(280):
        pb1.reset_label(QtableKey2state(all_states[j]))
Example #7
label_history[0].argmax(dim=1)

a = [test.episodes[i].loop_start_position for i in range(100)]
b = [len(test.valid_states[i]) for i in range(100)]
sum(a)
sum(b)
test.state_eval[3][1,:]
self = test
bingo = []
sway2 = []
zero = []
for i in range(100):
    self.problem.g = self.episodes[i].init_state
    end_perm = self.episodes[i].label_perm[self.end_of_episode[i]]
    end_label = self.episodes[i].init_state.ndata['label'][end_perm]
    dqn_result = state2QtableKey(end_label.argmax(dim=1).cpu().numpy())
    opt_result = state2QtableKey(validation_problem1[i][1])
    grd_result = state2QtableKey(validation_problem1[i][2])
    if opt_result == dqn_result:
        bingo.append(i)

    a1 = self.episodes[i].action_seq[-1][0] * 10000 + self.episodes[i].action_seq[-1][1]
    a2 = self.episodes[i].action_seq[-2][0] * 10000 + self.episodes[i].action_seq[-2][1]
    if a1 == a2 and a1 > 0:
        sway2.append(i)
    elif a1 == 0 and a2 == 0:
        zero.append(i)

# hard validation set
test.run_test(problem=to_cuda(bg_hard), trial_num=1, batch_size=100, gnn_step=gnn_step,
              episode_len=episode_len, explore_prob=0.0, Temperature=1e-8)
Example #8
    def run_test(self,
                 problem=None,
                 init_trial=10,
                 trial_num=1,
                 batch_size=100,
                 gnn_step=3,
                 episode_len=50,
                 explore_prob=0.1,
                 Temperature=1.0,
                 aux_model=None,
                 beta=0):
        self.episode_len = episode_len
        self.num_actions = self.problem.get_legal_actions(
            action_type=self.action_type).shape[0]
        if isinstance(self.problem.g, BatchedDGLGraph):
            self.num_actions //= self.problem.g.batch_size
        # if self.action_type == 'swap':
        #     self.num_actions = self.problem.N * (self.problem.N - self.problem.m) // 2 + 1
        # else:
        #     self.num_actions = self.problem.N * (self.problem.k - 1) + 1
        self.trial_num = trial_num
        self.batch_size = batch_size
        if problem is None:
            batch_size *= trial_num
            bg = self.problem.gen_batch_graph(batch_size=self.batch_size)
            self.bg = to_cuda(
                self.problem.gen_batch_graph(x=bg.ndata['x'].repeat(
                    trial_num, 1),
                                             batch_size=batch_size))
            gl = dgl.unbatch(self.bg)
            test_problem = self.problem
        else:
            if aux_model is not None and init_trial > 1:
                gg = []
                gl = dgl.unbatch(problem)
                for i in range(batch_size):
                    init_label_i = self.test_init_state(alg=self.alg,
                                                        aux=aux_model,
                                                        g_i=i,
                                                        t=episode_len,
                                                        trial_num=init_trial)
                    self.problem.g = gl[i]
                    self.problem.reset_label(label=init_label_i)
                    gg.append(dc(self.problem.g))
                self.bg = dgl.batch(gg)
                self.problem.g = self.bg
                test_problem = self.problem
            else:
                # for validation problem set
                self.bg = problem
                gl = dgl.unbatch(self.bg)
                self.problem.g = self.bg
                test_problem = self.problem

        # index offset of each episode's block in the flattened action tensor
        self.action_mask = torch.arange(
            0, self.num_actions * batch_size, self.num_actions).cuda()

        self.S = [
            test_problem.calc_S(g=gl[i]).cpu() for i in range(batch_size)
        ]

        ep = [
            EpisodeHistory(gl[i],
                           max_episode_len=episode_len,
                           action_type='swap') for i in range(batch_size)
        ]
        self.end_of_episode = {}.fromkeys(range(batch_size), episode_len - 1)
        loop_start_position = [0] * self.batch_size
        loop_remain_index = set(range(self.batch_size))
        self.valid_states = [[ep[i].init_state]
                             for i in range(self.batch_size)]
        for i in tqdm(range(episode_len)):
            batch_legal_actions = test_problem.get_legal_actions(
                state=self.bg, action_type=self.action_type).cuda()

            # if self.forbid_revisit and i > 0:
            #
            #     previous_actions = torch.tensor([ep[k].action_seq[i-1][0] * (self.n ** 2) + ep[k].action_seq[i-1][1] for k in range(batch_size)])
            #     bla = (self.n ** 2) * batch_legal_actions.view(batch_size, -1, 2)[:, :, 0] + batch_legal_actions.view(batch_size, -1, 2)[:, :, 1]
            #     forbid_action_mask = (bla.t() == previous_actions.cuda())
            #     batch_legal_actions = batch_legal_actions * (1 - forbid_action_mask.int()).t().flatten().unsqueeze(1)

            forbid_action_mask = torch.zeros(batch_legal_actions.shape[0],
                                             1).cuda()
            if self.forbid_revisit:
                for k in range(batch_size):
                    cur_state = ep[k].enc_state_seq[
                        -1]  # current state for the k-th episode
                    forbid_states = set(
                        ep[k].enc_state_seq[-1 - self.forbid_revisit:-1])
                    candidate_actions = batch_legal_actions[
                        k * self.num_actions:(k + 1) * self.num_actions, :]
                    for j in range(self.num_actions):
                        # simulate the candidate swap via the triple-XOR trick
                        # and forbid it if it revisits a recent state
                        cur_state_l = QtableKey2state(cur_state)
                        u, v = candidate_actions[j][0], candidate_actions[j][1]
                        cur_state_l[u] ^= cur_state_l[v]
                        cur_state_l[v] ^= cur_state_l[u]
                        cur_state_l[u] ^= cur_state_l[v]
                        if state2QtableKey(cur_state_l) in forbid_states:
                            forbid_action_mask[j + k * self.num_actions] += 1

            batch_legal_actions = batch_legal_actions * (
                1 - forbid_action_mask.int()).t().flatten().unsqueeze(1)

            if self.q_net == 'mlp':
                # bg1 = to_cuda(self.bg)
                # bg1.ndata['label'] = bg1.ndata['label'][:,[0,2,1]]
                # bg2 = to_cuda(self.bg)
                # bg2.ndata['label'] = bg2.ndata['label'][:,[1,0,2]]
                # bg3 = to_cuda(self.bg)
                # bg3.ndata['label'] = bg3.ndata['label'][:,[1,2,0]]
                # bg4 = to_cuda(self.bg)
                # bg4.ndata['label'] = bg4.ndata['label'][:,[2,0,1]]
                # bg5 = to_cuda(self.bg)
                # bg5.ndata['label'] = bg5.ndata['label'][:,[2,1,0]]

                # S_a_encoding, h1, h2, Q_sa = self.alg.forward(dgl.batch([self.bg, bg1, bg2, bg3, bg4, bg5]), torch.cat([batch_legal_actions]*6),
                #                                               action_type=self.action_type, gnn_step=gnn_step)

                S_a_encoding, h1, h2, Q_sa = self.alg.forward(
                    self.bg,
                    batch_legal_actions,
                    action_type=self.action_type,
                    gnn_step=gnn_step)
                # Q_sa = Q_sa.view(6, -1, self.num_actions).sum(dim=0)
            else:
                S_a_encoding, h1, h2, Q_sa = self.alg.forward_MHA(
                    self.bg,
                    batch_legal_actions,
                    action_type=self.action_type,
                    gnn_step=gnn_step)

            if beta > 0:
                self.problem.sample_episode *= self.num_actions
                self.problem.gen_step_batch_mask()

                new_bg, _ = self.problem.step_batch(
                    states=dgl.batch([self.bg] * self.num_actions),
                    action=batch_legal_actions.view(batch_size,
                                                    self.num_actions,
                                                    2).transpose(0, 1).reshape(
                                                        -1, 2),
                    action_type=self.action_type,
                    return_sub_reward=False)
                self.problem.sample_episode = self.problem.sample_episode // self.num_actions
                self.problem.gen_step_batch_mask()

                state_eval = self.get_state_eval_from_aux(
                    new_bg, self.alg, aux_model)

                state_eval = state_eval.view(-1, batch_size).t().reshape(
                    batch_size * self.num_actions)
                self.state_eval.append(
                    state_eval.view(batch_size, self.num_actions))

                Q_sa -= state_eval * beta

            terminate_episode = (Q_sa.view(-1, self.num_actions).max(
                dim=1).values < 0.).nonzero().flatten().cpu().numpy()

            for idx in terminate_episode:
                if self.end_of_episode[idx] == episode_len - 1:
                    self.end_of_episode[idx] = i

            # Boltzmann sampling over Q-values (softmax across the action dim)
            best_actions = torch.multinomial(
                F.softmax(Q_sa.view(-1, self.num_actions) / Temperature, dim=1),
                1).view(-1)

            # epsilon-greedy: with prob. explore_prob replace the sampled
            # action with a uniformly random legal action
            chosen_actions = torch.tensor([
                x if torch.rand(1) > explore_prob else torch.randint(
                    high=self.num_actions, size=(1, )).squeeze()
                for x in best_actions
            ]).cuda()
            chosen_actions += self.action_mask

            actions = batch_legal_actions[chosen_actions]

            new_states, rewards = self.problem.step_batch(
                states=self.bg, action=actions, action_type=self.action_type)
            R = [reward.item() for reward in rewards]

            # enc_states = [state2QtableKey(self.bg.ndata['label'][k * self.n: (k + 1) * self.n].argmax(dim=1).cpu().numpy()) for k in range(batch_size)]

            [
                ep[k].write(
                    action=actions[k, :],
                    action_idx=best_actions[k] - self.action_mask[k],
                    reward=R[k],
                    q_val=Q_sa.view(-1, self.num_actions)[k, :],
                    actions=batch_legal_actions.view(-1, self.num_actions,
                                                     2)[k, :, :],
                    state_enc=None,  # enc_states[k]
                    # sub_reward=sub_rewards[:, k].cpu().numpy(),
                    sub_reward=None,
                    loop_start_position=loop_start_position[k])
                for k in range(batch_size)
            ]

            # # write valid sample episode for training aux model
            # tmp = set()
            # for k in loop_remain_index:
            #     if enc_states[k] in ep[k].enc_state_seq[-4: -1]:
            #         loop_start_position[k] = i
            #         tmp.add(k)
            #     else:
            #         label = ep[k].init_state.ndata['label'][ep[k].label_perm[-1]]
            #         l = label.argmax(dim=1)
            #         self.problem.g = dc(ep[k].init_state)
            #         self.problem.reset_label(label=l)
            #         self.valid_states[k].append(dc(self.problem.g))
            # loop_remain_index = loop_remain_index.difference(tmp)

        self.episodes = ep
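Action selection in `run_test` is Boltzmann sampling: each episode draws an action with probability proportional to exp(Q / Temperature), so a tiny Temperature (such as the 1e-8 passed in the earlier validation call) is effectively greedy argmax, while large values approach uniform exploration. A self-contained sketch with made-up Q-values:

import torch
import torch.nn.functional as F

q = torch.tensor([[0.1, 0.5, 0.3, -0.2, 0.4]])   # made-up Q-values, 5 actions

for temperature in (1e-8, 1.0, 100.0):
    probs = F.softmax(q / temperature, dim=1)     # Boltzmann distribution
    action = torch.multinomial(probs, 1).item()   # sampled action index
    print(temperature, probs.numpy().round(3), action)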
Example #9
    def run_cascade_episode(self, action_type='swap', gnn_step=3):

        sum_r = 0
        new_graphs = [
            to_cuda(
                self.problem.gen_batch_graph(
                    batch_size=self.new_epi_batch_size))
        ]

        new_graphs.extend(
            list(
                itertools.chain(*[[
                    tpl.s1 for tpl in self.cascade_replay_buffer[i]
                    [-self.new_epi_batch_size:]
                ] for i in range(self.buf_epi_len - 1)])))
        bg = to_cuda(dgl.batch(new_graphs))

        batch_size = self.new_epi_batch_size * self.buf_epi_len
        num_actions = self.problem.get_legal_actions(
            action_type=action_type,
            action_dropout=self.action_dropout).shape[0]

        # recent_states = list(itertools.chain(
        #     *[[tpl.recent_states for tpl in self.cascade_replay_buffer[t][-self.new_epi_batch_size:]] for t in
        #       range(self.buf_epi_len - 1)]))
        #
        # batch_legal_actions = self.prune_actions(bg, recent_states)
        batch_legal_actions = self.problem.get_legal_actions(
            state=bg,
            action_type=self.action_type,
            action_dropout=self.action_dropout).cuda()

        # epsilon greedy strategy
        # TODO: multi-gpu parallelization
        _, _, _, Q_sa = self.model(dc(bg),
                                   batch_legal_actions,
                                   action_type=action_type,
                                   gnn_step=gnn_step)

        # TODO: can alter explore strength according to kcut_valueS
        kcut_valueS = self.problem.calc_batchS(bg=bg)
        chosen_actions = self.sample_actions_from_q(Q_sa,
                                                    num_actions,
                                                    batch_size,
                                                    Temperature=self.eps)
        actions = batch_legal_actions[chosen_actions]

        # update bg inplace and calculate batch rewards
        g0 = [g for g in dgl.unbatch(dc(bg))]  # current_state

        _, rewards, sub_rewards = self.problem.step_batch(
            states=bg, action=actions, return_sub_reward=True)
        g1 = [g for g in dgl.unbatch(dc(bg))]  # after_state
        # kcut_valueS_ = self.problem.calc_batchS(bg=bg)
        # print('assert=0:', kcut_valueS_ - kcut_valueS + rewards)
        # rewards_ = rewards.reshape(self.buf_epi_len, -1)
        # posi_r = rewards_ >= 0
        # nege_r = rewards_ < 0
        # scaled_rewards = (rewards_-torch.mean(rewards_, dim=1).t().unsqueeze(1)) / torch.std(rewards_, dim=1).t().unsqueeze(1)
        # scaled_rewards = scaled_rewards * posi_r + rewards_ * nege_r
        #
        # scaled_rewards = scaled_rewards.view(-1)

        # update episode best S

        best_S = torch.cat([
            kcut_valueS[:self.new_epi_batch_size].cpu(),
            torch.tensor([[
                np.min([
                    kcut_valueS[j + t * self.new_epi_batch_size],
                    self.cascade_replay_buffer[t - 1][j - self.new_epi_batch_size].best_S
                ]) for j in range(self.new_epi_batch_size)
            ] for t in range(1, self.buf_epi_len)]).flatten()
        ])

        # update recent states
        recent_states = [[
            state2QtableKey(g.ndata['label'].argmax(dim=1).cpu().numpy())
        ] for g in g0[:self.new_epi_batch_size]]
        recent_states.extend(
            list(
                itertools.chain(*[[
                    self.cascade_replay_buffer[t - 1]
                    [j - self.new_epi_batch_size].recent_states + [
                        state2QtableKey(g0[j + t * self.new_epi_batch_size].
                                        ndata['label'].argmax(
                                            dim=1).cpu().numpy())
                    ] for j in range(self.new_epi_batch_size)
                ] for t in range(1, self.buf_epi_len)])))

        memory_len = 10
        recent_states = [state[-memory_len:] for state in recent_states]
        forbid_actions = []

        [
            self.cascade_replay_buffer[t].extend([
                sars(g0[j + t * self.new_epi_batch_size],
                     actions[j + t * self.new_epi_batch_size],
                     rewards[j + t * self.new_epi_batch_size],
                     sub_rewards[:, j + t * self.new_epi_batch_size],
                     g1[j + t * self.new_epi_batch_size],
                     kcut_valueS[j + t * self.new_epi_batch_size],
                     best_S[j + t * self.new_epi_batch_size], t,
                     recent_states[j + t * self.new_epi_batch_size],
                     forbid_actions) for j in range(self.new_epi_batch_size)
            ]) for t in range(self.buf_epi_len)
        ]

        self.cascade_buffer_kcut_value = torch.cat(
            [
                self.cascade_buffer_kcut_value,
                torch.abs(kcut_valueS.detach().cpu().view(
                    self.buf_epi_len, self.new_epi_batch_size))
            ],
            dim=1).detach()

        if self.priority_sampling:
            # compute prioritized weights
            batch_legal_actions = self.problem.get_legal_actions(
                state=bg,
                action_type=action_type,
                action_dropout=self.action_dropout).cuda()
            _, _, _, Q_sa_next = self.model(dc(bg),
                                            batch_legal_actions,
                                            action_type=action_type,
                                            gnn_step=gnn_step)

            delta = Q_sa[chosen_actions] - (
                rewards +
                self.gamma * Q_sa_next.view(-1, num_actions).max(dim=1).values)
            # delta = (Q_sa[chosen_actions] - (rewards + self.gamma * Q_sa_next.view(-1, num_actions).max(dim=1).values)) / torch.clamp(torch.abs(Q_sa[chosen_actions]),0.1)
            self.cascade_replay_buffer_weight = torch.cat(
                [
                    self.cascade_replay_buffer_weight,
                    torch.abs(delta.detach().cpu().view(
                        self.buf_epi_len, self.new_epi_batch_size))
                ],
                dim=1).detach()
            # print(self.cascade_replay_buffer_weight)
        R = [reward.item() for reward in rewards]
        sum_r += sum(R)

        self.log.add_item('tot_return', sum_r)

        return R
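When `priority_sampling` is enabled, the replay weight stored for each transition is the absolute one-step TD error |Q(s,a) - (r + gamma * max_a' Q(s',a'))|, computed above from the freshly stepped batch. A toy version of that weight computation on made-up tensors:

import torch

gamma = 0.9
num_actions = 4

q_sa_chosen = torch.tensor([0.8, 0.2, -0.1])   # Q(s, a) of the chosen actions
rewards = torch.tensor([1.0, -0.5, 0.3])       # observed rewards
q_sa_next = torch.randn(3 * num_actions)       # Q(s', a') for every legal next action

delta = q_sa_chosen - (rewards +
                       gamma * q_sa_next.view(-1, num_actions).max(dim=1).values)
priority = torch.abs(delta)                    # used as the sampling weight
print(priority)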