Code example #1
File: rev_solver.py Project: chenjw259/gnn_rl
    def test_greedy(self, graph, path=None):
        self.problem.g = to_cuda(graph)
        res = [self.problem.calc_S().item()]
        pb = self.problem
        if path is not None:
            path = os.path.abspath(os.path.join(os.getcwd())) + path
            vis_g(pb, name=path + str(0), topo='cut')
        R = []
        Reward = []
        for j in range(100):
            M = []
            actions = pb.get_legal_actions()
            for k in range(actions.shape[0]):
                _, r = pb.step(actions[k], state=dc(pb.g))
                M.append(r)
            if max(M) <= 0:
                break
            if path is not None:
                vis_g(pb, name=path + str(j + 1), topo='cut')
            posi = [x for x in M if x > 0]
            nega = [x for x in M if x <= 0]
            # print('posi reward ratio:', len(posi) / len(M))
            # print('posi reward avg:', sum(posi) / len(posi))
            # print('nega reward avg:', sum(nega) / len(nega))
            max_idx = torch.tensor(M).argmax().item()
            _, r = pb.step((actions[max_idx, 0].item(), actions[max_idx, 1].item()))
            R.append((actions[max_idx, 0].item(), actions[max_idx, 1].item(), r.item()))
            Reward.append(r.item())
            res.append(res[-1] - r.item())
        return QtableKey2state(
            state2QtableKey(
                pb.g.ndata['label'].argmax(dim=1).cpu().numpy())), R, res
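The method above is a plain greedy local search: it scores every legal move, stops when no move yields a positive gain, and otherwise applies the best one. Below is a minimal, self-contained sketch of that pattern; state, legal_moves, gain and apply_move are hypothetical placeholders, not part of gnn_rl.

def greedy_search(state, legal_moves, gain, apply_move, max_steps=100):
    # Score every legal move; stop at a local optimum (no positive gain),
    # otherwise apply the best move and record it.
    trajectory = []
    for _ in range(max_steps):
        moves = legal_moves(state)
        gains = [gain(state, mv) for mv in moves]
        if not gains or max(gains) <= 0:
            break
        best = max(range(len(gains)), key=gains.__getitem__)
        state = apply_move(state, moves[best])
        trajectory.append((moves[best], gains[best]))
    return state, trajectory

# Toy usage: greedily climb f(s) = -(s - 3)**2 starting from 0.
f = lambda s: -(s - 3) ** 2
final, trace = greedy_search(
    state=0,
    legal_moves=lambda s: [+1, -1],
    gain=lambda s, mv: f(s + mv) - f(s),
    apply_move=lambda s, mv: s + mv)
print(final)  # 3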
Code example #2
File: parse_sample.py Project: chenjw259/gnn_rl
def map_psar2g(psar, hidden_dim=16, rewire_edges=False):

    # TODO: if the cluster topology is used in the GNN,
    #  then rewire_edges=True is required (can be time-consuming)
    psar.p.reset_label(label=psar.s, calc_S=False, rewire_edges=rewire_edges)
    g = to_cuda(psar.p.g)
    g.ndata['h'] = torch.zeros((g.number_of_nodes(), hidden_dim)).cuda()
    return g
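map_psar2g boils down to relabeling the problem graph and attaching an all-zero hidden-state matrix 'h' for the GNN. The following toy sketch shows only that last step, written against the current DGL API (the repo itself targets an older DGL release and additionally moves tensors to the GPU):

import dgl
import torch

hidden_dim = 16
g = dgl.graph(([0, 1, 2], [1, 2, 0]))  # hypothetical 3-node toy graph
g.ndata['h'] = torch.zeros(g.num_nodes(), hidden_dim)  # zero hidden states
print(g.ndata['h'].shape)  # torch.Size([3, 16])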
Code example #3
File: episode_stats.py Project: chenjw259/gnn_rl
    def test_init_state(self, alg, aux, g_i, t, trial_num=10):

        init_state = dc(self.episodes[g_i].init_state)

        init_states_label = [
            QtableKey2state(key) for key in np.random.choice(
                self.all_states, trial_num, replace=False)
        ]

        evals = []
        for k in range(trial_num):

            print('trial:', k)
            self.problem.g = dc(init_state)

            self.problem.reset_label(label=init_states_label[k])
            print('init state:', init_states_label[k])
            S = self.problem.calc_S()

            bg = dgl.batch([self.problem.g])

            state_eval = self.get_state_eval_from_aux(bg, alg, aux)

            evals.append(state_eval.detach().item())
            print('state_eval:', state_eval)

            # print('init S:', S)

            for i in range(t):

                actions = self.problem.get_legal_actions(
                    state=self.problem.g, action_type=self.action_type)
                S_a_encoding, h1, h2, Q_sa = alg.forward(
                    to_cuda(self.problem.g),
                    actions.cuda(),
                    action_type=self.action_type,
                    gnn_step=3)
                _, r = self.problem.step(state=self.problem.g,
                                         action=actions[torch.argmax(Q_sa)],
                                         action_type=self.action_type)
                S -= r.item()

                # print(Q_sa.detach().cpu().numpy())
                # print('action index:', torch.argmax(Q_sa).detach().cpu().item())
                # print('action:', actions[torch.argmax(Q_sa)].detach().cpu().numpy())
                # print('reward:', r.item())
                # print('kcut S:', S)

            print(S)
        return init_states_label[torch.tensor(evals).argmax()]
Code example #4
File: rev_solver.py Project: chenjw259/gnn_rl
    def cmpt_optimal(self, graph, path=None):

        self.problem.g = to_cuda(graph)
        res = [self.problem.calc_S().item()]
        pb = self.problem

        S = []
        for j in range(280):
            pb.reset_label(QtableKey2state(self.all_states[j]))
            S.append(pb.calc_S())

        s1 = torch.tensor(S).argmin()
        res.append(S[s1].item())

        if path is not None:
            path = os.path.abspath(os.path.join(os.getcwd())) + path
            pb.reset_label(QtableKey2state(self.all_states[s1]))
            vis_g(pb, name=path, topo='cut')
        return QtableKey2state(self.all_states[s1]), res
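The loop over range(280) enumerates every distinct balanced 3-partition of the 9 nodes: there are 9!/(3!)^3 = 1680 labelings, and dividing out the 3! interchangeable cluster labels leaves 280 states. The sketch below reproduces that count; the canonical() helper is a stand-in for what state2QtableKey presumably does.

import itertools

def canonical(labels):
    # Relabel groups by order of first appearance so that the 3! label
    # permutations of the same partition collapse to one representative.
    mapping, out = {}, []
    for l in labels:
        mapping.setdefault(l, len(mapping))
        out.append(mapping[l])
    return tuple(out)

states = {canonical(p)
          for p in itertools.permutations([0] * 3 + [1] * 3 + [2] * 3)}
print(len(states))  # 280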
Code example #5
File: episode_stats.py Project: chenjw259/gnn_rl
    def test_dqn(self, alg, g_i, t, init_label=None, path=None):

        init_state = dc(self.episodes[g_i].init_state)
        label_history = init_state.ndata['label'][
            self.episodes[g_i].label_perm]

        self.problem.g = dc(init_state)
        if init_label is None:
            self.problem.reset_label(label=label_history[0].argmax(dim=1))
        else:
            self.problem.reset_label(label=init_label)
        S = self.problem.calc_S()
        print('init S:', S)
        if path is not None:
            path = os.path.abspath(os.path.join(os.getcwd())) + path
            vis_g(self.problem, name=path + str(0), topo='cut')

        for i in range(t):
            actions = self.problem.get_legal_actions(
                state=self.problem.g, action_type=self.action_type)
            S_a_encoding, h1, h2, Q_sa = alg.forward(
                to_cuda(self.problem.g),
                actions.cuda(),
                action_type=self.action_type,
                gnn_step=3)
            _, r = self.problem.step(state=self.problem.g,
                                     action=actions[torch.argmax(Q_sa)],
                                     action_type=self.action_type)

            print(Q_sa.detach().cpu().numpy())
            print('action index:', torch.argmax(Q_sa).detach().cpu().item())
            print('action:',
                  actions[torch.argmax(Q_sa)].detach().cpu().numpy())
            print('reward:', r.item())
            S -= r.item()
            print('kcut S:', S)
            if path is not None:
                vis_g(self.problem, name=path + str(i + 1), topo='cut')
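test_dqn is a greedy Q-policy rollout: at every step it scores all legal actions with the trained network and applies the argmax action. A compact sketch of that loop follows, where q_fn, legal_actions and apply_action are hypothetical stand-ins for alg.forward, problem.get_legal_actions and problem.step.

import torch

def greedy_rollout(state, q_fn, legal_actions, apply_action, steps=10):
    # Repeatedly pick the action with the highest predicted Q-value.
    total_reward = 0.0
    for _ in range(steps):
        actions = legal_actions(state)
        q_values = q_fn(state, actions)           # shape: (num_actions,)
        best = torch.argmax(q_values).item()
        state, reward = apply_action(state, actions[best])
        total_reward += reward
    return state, total_reward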
Code example #6
File: run.py Project: chenjw259/gnn_rl
def run_dqn(alg):

    t = 0
    writer = SummaryWriter(log_folder + save_folder)
    for n_iter in tqdm(range(n_epoch)):

        T1 = time.time()

        if n_iter > len(eps) - 1:
            alg.eps = eps[-1]
        else:
            alg.eps = eps[n_iter]

        # if n_iter < 2500:
        #     alg.gamma = 0
        # elif n_iter < 5000:
        #     alg.gamma = 0.3
        # elif n_iter < 7500:
        #     alg.gamma = 0.5
        # elif n_iter < 10000:
        #     alg.gamma = 0.7
        # elif n_iter < 15000:
        #     alg.gamma = 0.9

        T11 = time.time()
        # TODO memory usage :: episode_len * num_episodes * hidden_dim
        log = alg.train_dqn(target_bg=target_bg
                            , epoch=n_iter
                            , batch_size=batch_size
                            , num_episodes=n_episode
                            , episode_len=episode_len
                            , gnn_step=gnn_step
                            , q_step=q_step
                            , rollout_step=rollout_step
                            , ddqn=ddqn)
        if n_iter % target_update_step == target_update_step - 1:
            alg.update_target_net()
        T22 = time.time()
        print('Epoch: {}. R: {}. Q error: {}. H: {}. T: {}'
              .format(n_iter
               , np.round(log.get_current('tot_return'), 2)
               , np.round(log.get_current('Q_error'), 3)
               , np.round(log.get_current('entropy'), 3)
               , np.round(T22 - T11, 3)))

        T2 = time.time()

        if n_iter % save_ckpt_step == save_ckpt_step - 1:
            with open(path + 'dqn_'+str(model_version + n_iter + 1), 'wb') as model_file:
                pickle.dump(alg.model, model_file)
            t += 1
            # with open(path + 'buffer_' + str(model_version + n_iter + 1), 'wb') as model_file:
            #     pickle.dump([alg.cascade_replay_buffer_weight, [[problem.calc_S(g=elem.s0) for elem in alg.cascade_replay_buffer[i]] for i in range(len(alg.cascade_replay_buffer))]]
            #                  , model_file)

        # validation

        # test summary
        if n_iter % 100 == 0:
            test = test_summary(alg=alg, problem=test_problem1, action_type=action_type, q_net=readout, forbid_revisit=0)

            test.run_test(problem=to_cuda(bg_hard), trial_num=1, batch_size=100, gnn_step=gnn_step,
                          episode_len=episode_len, explore_prob=0.0, Temperature=1e-8)
            epi_r0 = test.show_result()
            if run_validation_33:
                best_hit0 = test.compare_opt(validation_problem0)

            test.problem = test_problem0
            test.run_test(problem=to_cuda(bg_easy), trial_num=1, batch_size=test_episode, gnn_step=gnn_step,
                          episode_len=episode_len, explore_prob=0.0, Temperature=1e-8)
            epi_r1 = test.show_result()
            if run_validation_33:
                best_hit1 = test.compare_opt(validation_problem1)

            test.problem = test_problem2
            test.run_test(problem=to_cuda(bg_subopt), trial_num=1, batch_size=100, gnn_step=gnn_step,
                          episode_len=episode_len, explore_prob=0.0, Temperature=1e-8)
            epi_r2 = test.show_result()
            if run_validation_33:
                best_hit2 = test.compare_opt(validation_problem0)

        writer.add_scalar('Reward/Training Episode Reward', log.get_current('tot_return') / n_episode, n_iter)
        writer.add_scalar('Loss/Q-Loss', log.get_current('Q_error'), n_iter)
        writer.add_scalar('Reward/Validation Episode Reward - hard', epi_r0, n_iter)
        writer.add_scalar('Reward/Validation Episode Reward - easy', epi_r1, n_iter)
        writer.add_scalar('Reward/Validation Episode Reward - subopt', epi_r2, n_iter)
        if run_validation_33:
            writer.add_scalar('Reward/Validation Opt. hit percent - hard', best_hit0, n_iter)
            writer.add_scalar('Reward/Validation Opt. hit percent - easy', best_hit1, n_iter)
            writer.add_scalar('Reward/Validation Opt. hit percent - subopt', best_hit2, n_iter)
        writer.add_scalar('Time/Running Time per Epoch', T2 - T1, n_iter)
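run_dqn wraps training in a simple schedule: the exploration rate is read from a decay table and clamped at its last entry, the target network is refreshed every target_update_step epochs, and scalars are written to TensorBoard. The skeleton below shows only the epsilon clamp and the logging calls; the decay table and scalar values are made up for illustration.

import numpy as np
from torch.utils.tensorboard import SummaryWriter

eps = np.linspace(1.0, 0.1, 100)                  # assumed decay table
writer = SummaryWriter('runs/dqn_demo')           # hypothetical log directory
for n_iter in range(200):
    current_eps = eps[min(n_iter, len(eps) - 1)]  # clamp at the final value
    fake_return = -0.1 * current_eps              # placeholder for tot_return
    writer.add_scalar('Reward/Training Episode Reward', fake_return, n_iter)
    writer.add_scalar('Loss/Q-Loss', 1.0 / (n_iter + 1), n_iter)
writer.close()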
Code example #7
File: run.py Project: chenjw259/gnn_rl
else:
    mode = 'w'
with open(path + 'params', mode) as params_file:
    params_file.write(time.strftime("%Y--%m--%d %H:%M:%S", time.localtime()))
    params_file.write('\n------------------------------------\n')
    params_file.write(json.dumps(args))
    params_file.write('\n------------------------------------\n')


# load validation set
if run_validation_33:
    with open('/p/reinforcement/data/gnn_rl/model/test_data/3by3/0', 'rb') as valid_file:
        validation_problem0 = pickle.load(valid_file)
    with open('/p/reinforcement/data/gnn_rl/model/test_data/3by3/1', 'rb') as valid_file:
        validation_problem1 = pickle.load(valid_file)
    bg_hard = to_cuda(dgl.batch([p[0].g for p in validation_problem0[:test_episode]]))
    bg_easy = to_cuda(dgl.batch([p[0].g for p in validation_problem1[:test_episode]]))

    bg_subopt = []
    for i in range(test_episode):
        gi = to_cuda(validation_problem0[:test_episode][i][0].g)
        problem.reset_label(g=gi, label=validation_problem0[:test_episode][i][2])
        bg_subopt.append(gi)
    bg_subopt = dgl.batch(bg_subopt)

    for bg_ in [bg_hard, bg_easy, bg_subopt]:
        if ajr == 8:
            bg_.edata['e_type'][:, 0] = torch.ones(N * ajr * bg_.batch_size)
        _, _, square_dist_matrix = dgl.transform.knn_graph(bg_.ndata['x'].view(test_episode, N, -1), ajr+1, extend_info=True)
        square_dist_matrix = F.relu(square_dist_matrix, inplace=True)  # numerical error can make squared distances slightly negative, which would yield NaN under sqrt
        bg_.ndata['adj'] = torch.sqrt(square_dist_matrix).view(bg_.number_of_nodes(), -1)
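The F.relu before torch.sqrt above is a numerical-safety clamp: squared distances returned by the kNN routine can dip slightly below zero through floating-point error, which would turn into NaN under the square root. A small standalone illustration of the same pattern:

import torch
import torch.nn.functional as F

x = torch.randn(9, 2)                   # toy node coordinates
sq_dist = torch.cdist(x, x).pow(2)      # squared pairwise distances
sq_dist = sq_dist - 1e-7                # emulate a tiny numerical error
dist = torch.sqrt(F.relu(sq_dist))      # clamp at zero, then take the root
assert not torch.isnan(dist).any()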
Code example #8
File: Canonical_solvers.py Project: chenjw259/gnn_rl
greedy_move_history = []
all_states = list(
    set([
        state2QtableKey(x)
        for x in list(itertools.permutations([0, 0, 0, 1, 1, 1, 2, 2, 2], 9))
    ]))

select_problem = []

for i in range(20000):

    # # brute force
    problem.reset()
    res[0, i] = problem.calc_S().item()
    pb1 = dc(problem)
    pb1.g = to_cuda(pb1.g)
    S = []
    for j in range(280):
        pb1.reset_label(QtableKey2state(all_states[j]))
        S.append(pb1.calc_S())

    s1 = torch.tensor(S).argmin()
    pb1.reset_label(QtableKey2state(all_states[s1]))
    res[1, i] = S[s1]
    # print(s1)

    # greedy algorithm
    # problem.g = gg[i]
    pb2 = dc(problem)
    pb2.g = to_cuda(pb2.g)
    # pb2 = problem
Code example #9
        problem.reset()
        init_S.append(problem.calc_S())

        _, _, sq_dist_matrix = dgl.transform.knn_graph(problem.g.ndata['x'],
                                                       1,
                                                       extend_info=True)
        mat_5by6 = (2 - torch.sqrt(F.relu(
            sq_dist_matrix, inplace=True))[0]).numpy().astype('float64')
        m_path = os.path.abspath(os.path.join(
            os.getcwd())) + '/toy_models/ga_helpers/corr_mat/dqn_5by6.mat'
        dump_matrix(mat_5by6, m_path)

        # path_m = os.path.abspath(os.path.join(os.getcwd())) + '/ga_helpers/corr_mat/dqn_5by6.mat'
        # path_c = os.path.abspath(os.path.join(os.getcwd())) + '/ga_helpers/configs/default.json'
        best_solution, acc, best_fitness = run_ga(path_m, path_c)

        best_label = np.zeros(k * m).astype('int')
        for i in range(k):
            best_label[best_solution[i, :]] += i
        problem.g = to_cuda(problem.g)
        problem.reset_label(label=best_label)
        # print('Final S1:', problem.calc_S())
        # print('Final S2:', k * m * (m - 1) - k * m * (m - 1) / 2 * best_fitness)
        best_S.append(problem.calc_S())
        # path = os.path.abspath(os.path.join(os.getcwd())) + '/toy_models/figs/test1'
        # vis_g(problem, name=path, topo='cut')

    print('init S:', sum(init_S) / num_trial)
    print('best S:', sum(best_S) / num_trial)
    print('gain ratio:', (sum(init_S) - sum(best_S)) / sum(init_S))
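The decoding step after run_ga assumes best_solution is a k x m index matrix whose row i lists the nodes assigned to group i; the loop flattens it into a per-node label vector (+= i on an all-zero array is equivalent to = i here because each node appears in exactly one row). A toy illustration under that assumed layout:

import numpy as np

k, m = 3, 3
best_solution = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])  # assumed layout
best_label = np.zeros(k * m, dtype=int)
for i in range(k):
    best_label[best_solution[i, :]] += i
print(best_label)  # [0 0 0 1 1 1 2 2 2]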
Code example #10
# bg_subopt = dgl.batch(bg_subopt)
# bg_opt = dgl.batch(bg_opt)
#
# if ajr == 8:
#     bg_hard.edata['e_type'][:, 0] = torch.ones(k * m * ajr * bg_hard.batch_size)
#     bg_easy.edata['e_type'][:, 0] = torch.ones(k * m * ajr * bg_easy.batch_size)
#     bg_subopt.edata['e_type'][:, 0] = torch.ones(k * m * ajr * bg_subopt.batch_size)
#     bg_opt.edata['e_type'][:, 0] = torch.ones(k * m * ajr * bg_opt.batch_size)
#

# random validation set
# why not generalise to larger graphs?
problem = KCut_DGL(k=k, m=3, adjacent_reserve=8, hidden_dim=h, mode=mode, sample_episode=sample_episode, graph_style='cluster')
test = test_summary(alg=alg, problem=problem, q_net=q_net, forbid_revisit=0)

bg = to_cuda(problem.gen_batch_graph(batch_size=batch_size, style='plain'))
test.run_test(problem=bg, trial_num=trial_num, batch_size=batch_size, gnn_step=gnn_step, episode_len=episode_len, explore_prob=explore_prob, Temperature=Temperature)
test.show_result()

# easy validation set
for beta in [0.1]:
    print('beta', beta)
    test.run_test(problem=to_cuda(bg_easy), init_trial=1, trial_num=1, batch_size=100, gnn_step=gnn_step,
                  episode_len=episode_len, explore_prob=0.0, Temperature=1e-8
                  , aux_model=None
                  , beta=0)
    epi_r1 = test.show_result()
    best_hit1 = test.compare_opt(validation_problem1)

j = 0
for i in range(100):
Code example #11
File: episode_stats.py Project: chenjw259/gnn_rl
    def run_test(self,
                 problem=None,
                 init_trial=10,
                 trial_num=1,
                 batch_size=100,
                 gnn_step=3,
                 episode_len=50,
                 explore_prob=0.1,
                 Temperature=1.0,
                 aux_model=None,
                 beta=0):
        self.episode_len = episode_len
        self.num_actions = self.problem.get_legal_actions(
            action_type=self.action_type).shape[0]
        if isinstance(self.problem.g, BatchedDGLGraph):
            self.num_actions //= self.problem.g.batch_size
        # if self.action_type == 'swap':
        #     self.num_actions = self.problem.N * (self.problem.N - self.problem.m) // 2 + 1
        # else:
        #     self.num_actions = self.problem.N * (self.problem.k - 1) + 1
        self.trial_num = trial_num
        self.batch_size = batch_size
        if problem is None:
            batch_size *= trial_num
            bg = self.problem.gen_batch_graph(batch_size=self.batch_size)
            self.bg = to_cuda(
                self.problem.gen_batch_graph(x=bg.ndata['x'].repeat(
                    trial_num, 1),
                                             batch_size=batch_size))
            gl = dgl.unbatch(self.bg)
            test_problem = self.problem
        else:
            if aux_model is not None and init_trial > 1:
                gg = []
                gl = dgl.unbatch(problem)
                for i in range(batch_size):
                    init_label_i = self.test_init_state(alg=self.alg,
                                                        aux=aux_model,
                                                        g_i=i,
                                                        t=episode_len,
                                                        trial_num=init_trial)
                    self.problem.g = gl[i]
                    self.problem.reset_label(label=init_label_i)
                    gg.append(dc(self.problem.g))
                self.bg = dgl.batch(gg)
                self.problem.g = self.bg
                test_problem = self.problem
            else:
                # for validation problem set
                self.bg = problem
                gl = dgl.unbatch(self.bg)
                self.problem.g = self.bg
                test_problem = self.problem

        self.action_mask = torch.tensor(
            range(0, self.num_actions * batch_size, self.num_actions)).cuda()

        self.S = [
            test_problem.calc_S(g=gl[i]).cpu() for i in range(batch_size)
        ]

        ep = [
            EpisodeHistory(gl[i],
                           max_episode_len=episode_len,
                           action_type='swap') for i in range(batch_size)
        ]
        self.end_of_episode = {}.fromkeys(range(batch_size), episode_len - 1)
        loop_start_position = [0] * self.batch_size
        loop_remain_index = set(range(self.batch_size))
        self.valid_states = [[ep[i].init_state]
                             for i in range(self.batch_size)]
        for i in tqdm(range(episode_len)):
            batch_legal_actions = test_problem.get_legal_actions(
                state=self.bg, action_type=self.action_type).cuda()

            # if self.forbid_revisit and i > 0:
            #
            #     previous_actions = torch.tensor([ep[k].action_seq[i-1][0] * (self.n ** 2) + ep[k].action_seq[i-1][1] for k in range(batch_size)])
            #     bla = (self.n ** 2) * batch_legal_actions.view(batch_size, -1, 2)[:, :, 0] + batch_legal_actions.view(batch_size, -1, 2)[:, :, 1]
            #     forbid_action_mask = (bla.t() == previous_actions.cuda())
            #     batch_legal_actions = batch_legal_actions * (1 - forbid_action_mask.int()).t().flatten().unsqueeze(1)

            forbid_action_mask = torch.zeros(batch_legal_actions.shape[0],
                                             1).cuda()
            if self.forbid_revisit:
                for k in range(batch_size):
                    cur_state = ep[k].enc_state_seq[
                        -1]  # current state for the k-th episode
                    forbid_states = set(
                        ep[k].enc_state_seq[-1 - self.forbid_revisit:-1])
                    candicate_actions = batch_legal_actions[
                        k * self.num_actions:(k + 1) * self.num_actions, :]
                    for j in range(self.num_actions):
                        cur_state_l = QtableKey2state(cur_state)
                        cur_state_l[candicate_actions[j][0]] ^= cur_state_l[
                            candicate_actions[j][1]]
                        cur_state_l[candicate_actions[j][1]] ^= cur_state_l[
                            candicate_actions[j][0]]
                        cur_state_l[candicate_actions[j][0]] ^= cur_state_l[
                            candicate_actions[j][1]]
                        if state2QtableKey(cur_state_l) in forbid_states:
                            forbid_action_mask[j + k * self.num_actions] += 1

            batch_legal_actions = batch_legal_actions * (
                1 - forbid_action_mask.int()).t().flatten().unsqueeze(1)

            if self.q_net == 'mlp':
                # bg1 = to_cuda(self.bg)
                # bg1.ndata['label'] = bg1.ndata['label'][:,[0,2,1]]
                # bg2 = to_cuda(self.bg)
                # bg2.ndata['label'] = bg2.ndata['label'][:,[1,0,2]]
                # bg3 = to_cuda(self.bg)
                # bg3.ndata['label'] = bg3.ndata['label'][:,[1,2,0]]
                # bg4 = to_cuda(self.bg)
                # bg4.ndata['label'] = bg4.ndata['label'][:,[2,0,1]]
                # bg5 = to_cuda(self.bg)
                # bg5.ndata['label'] = bg5.ndata['label'][:,[2,1,0]]

                # S_a_encoding, h1, h2, Q_sa = self.alg.forward(dgl.batch([self.bg, bg1, bg2, bg3, bg4, bg5]), torch.cat([batch_legal_actions]*6),
                #                                               action_type=self.action_type, gnn_step=gnn_step)

                S_a_encoding, h1, h2, Q_sa = self.alg.forward(
                    self.bg,
                    batch_legal_actions,
                    action_type=self.action_type,
                    gnn_step=gnn_step)
                # Q_sa = Q_sa.view(6, -1, self.num_actions).sum(dim=0)
            else:
                S_a_encoding, h1, h2, Q_sa = self.alg.forward_MHA(
                    self.bg,
                    batch_legal_actions,
                    action_type=self.action_type,
                    gnn_step=gnn_step)

            if beta > 0:
                self.problem.sample_episode *= self.num_actions
                self.problem.gen_step_batch_mask()

                new_bg, _ = self.problem.step_batch(
                    states=dgl.batch([self.bg] * self.num_actions),
                    action=batch_legal_actions.view(batch_size,
                                                    self.num_actions,
                                                    2).transpose(0, 1).reshape(
                                                        -1, 2),
                    action_type=self.action_type,
                    return_sub_reward=False)
                self.problem.sample_episode = self.problem.sample_episode // self.num_actions
                self.problem.gen_step_batch_mask()

                state_eval = self.get_state_eval_from_aux(
                    new_bg, self.alg, aux_model)

                state_eval = state_eval.view(-1, batch_size).t().reshape(
                    batch_size * self.num_actions)
                self.state_eval.append(
                    state_eval.view(batch_size, self.num_actions))

                Q_sa -= state_eval * beta

            terminate_episode = (Q_sa.view(-1, self.num_actions).max(
                dim=1).values < 0.).nonzero().flatten().cpu().numpy()

            for idx in terminate_episode:
                if self.end_of_episode[idx] == episode_len - 1:
                    self.end_of_episode[idx] = i

            best_actions = torch.multinomial(
                F.softmax(Q_sa.view(-1, self.num_actions) / Temperature),
                1).view(-1)

            chose_actions = torch.tensor([
                x if torch.rand(1) > explore_prob else torch.randint(
                    high=self.num_actions, size=(1, )).squeeze()
                for x in best_actions
            ]).cuda()
            chose_actions += self.action_mask

            actions = batch_legal_actions[chose_actions]

            new_states, rewards = self.problem.step_batch(
                states=self.bg, action=actions, action_type=self.action_type)
            R = [reward.item() for reward in rewards]

            # enc_states = [state2QtableKey(self.bg.ndata['label'][k * self.n: (k + 1) * self.n].argmax(dim=1).cpu().numpy()) for k in range(batch_size)]

            [
                ep[k].write(
                    action=actions[k, :],
                    action_idx=best_actions[k] - self.action_mask[k],
                    reward=R[k],
                    q_val=Q_sa.view(-1, self.num_actions)[k, :],
                    actions=batch_legal_actions.view(-1, self.num_actions,
                                                     2)[k, :, :],
                    state_enc=None  #  enc_states[k]
                    # , sub_reward=sub_rewards[:, k].cpu().numpy()
                    ,
                    sub_reward=None,
                    loop_start_position=loop_start_position[k])
                for k in range(batch_size)
            ]

            # # write valid sample episode for training aux model
            # tmp = set()
            # for k in loop_remain_index:
            #     if enc_states[k] in ep[k].enc_state_seq[-4: -1]:
            #         loop_start_position[k] = i
            #         tmp.add(k)
            #     else:
            #         label = ep[k].init_state.ndata['label'][ep[k].label_perm[-1]]
            #         l = label.argmax(dim=1)
            #         self.problem.g = dc(ep[k].init_state)
            #         self.problem.reset_label(label=l)
            #         self.valid_states[k].append(dc(self.problem.g))
            # loop_remain_index = loop_remain_index.difference(tmp)

        self.episodes = ep
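Action selection in run_test mixes two sources of randomness: Q-values are sampled through a temperature-scaled softmax (a tiny Temperature makes this effectively greedy), and with probability explore_prob the sampled action is replaced by a uniformly random one. A self-contained sketch of that rule:

import torch
import torch.nn.functional as F

def choose_action(q_values, temperature=1.0, explore_prob=0.1):
    # Sample from softmax(Q / T); occasionally fall back to a random action.
    probs = F.softmax(q_values / temperature, dim=-1)
    action = torch.multinomial(probs, 1).item()
    if torch.rand(1).item() < explore_prob:
        action = torch.randint(len(q_values), (1,)).item()
    return action

q = torch.tensor([0.2, -0.1, 0.7])
print(choose_action(q, temperature=1e-8, explore_prob=0.0))  # near-greedy -> 2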