Example #1
def run(save_loc="ER_100spin/s2v"):

    print("\n----- Running {} -----\n".format(os.path.basename(__file__)))

    ####################################################
    # SET UP ENVIRONMENT AND VARIABLES
    ####################################################

    gamma = 1
    step_fact = 1

    env_args = {
        'observables': [Observable.SPIN_STATE],
        'reward_signal': RewardSignal.DENSE,
        'extra_action': ExtraAction.NONE,
        'optimisation_target': OptimisationTarget.CUT,
        'spin_basis': SpinBasis.BINARY,
        'norm_rewards': True,
        'memory_length': None,
        'horizon_length': None,
        'stag_punishment': None,
        'basin_reward': None,
        'reversible_spins': False
    }
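    # These settings mirror the S2V-DQN baseline: the agent observes only the
    # raw spin state, receives dense rewards (the immediate change in cut
    # value), and cannot flip a spin back once it has been chosen.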

    ####################################################
    # SET UP TRAINING AND TEST GRAPHS
    ####################################################

    n_spins_train = 100

    train_graph_generator = RandomErdosRenyiGraphGenerator(
        n_spins=n_spins_train, p_connection=0.15, edge_type=EdgeType.DISCRETE)

    ####
    # Pre-generated test graphs
    ####
    graph_save_loc = "_graphs/testing/ER_100spin_p15_50graphs.pkl"
    graphs_test = load_graph_set(graph_save_loc)
    n_tests = len(graphs_test)

    test_graph_generator = SetGraphGenerator(graphs_test, ordered=True)

    ####################################################
    # SET UP TRAINING AND TEST ENVIRONMENTS
    ####################################################

    train_envs = [
        ising_env.make("SpinSystem", train_graph_generator,
                       int(n_spins_train * step_fact), **env_args)
    ]

    n_spins_test = train_graph_generator.get().shape[0]
    test_envs = [
        ising_env.make("SpinSystem", test_graph_generator,
                       int(n_spins_test * step_fact), **env_args)
    ]

    ####################################################
    # SET UP FOLDERS FOR SAVING DATA
    ####################################################

    data_folder = os.path.join(save_loc, 'data')
    network_folder = os.path.join(save_loc, 'network')

    mk_dir(data_folder)
    mk_dir(network_folder)
    # print(data_folder)
    network_save_path = os.path.join(network_folder, 'network.pth')
    test_save_path = os.path.join(network_folder, 'test_scores.pkl')
    loss_save_path = os.path.join(network_folder, 'losses.pkl')

    ####################################################
    # SET UP AGENT
    ####################################################

    nb_steps = 8000000

    network_fn = lambda: MPNN(n_obs_in=train_envs[0].observation_space.shape[1],
                              n_layers=3,
                              n_features=64,
                              n_hid_readout=[],
                              tied_weights=False)

    agent = DQN(
        train_envs,
        network_fn,
        init_network_params=None,
        init_weight_std=0.01,
        double_dqn=True,
        clip_Q_targets=True,
        replay_start_size=1500,
        replay_buffer_size=10000,  # 20000
        gamma=gamma,  # 1
        update_target_frequency=2500,  # 500
        update_learning_rate=False,
        initial_learning_rate=1e-4,
        peak_learning_rate=1e-4,
        peak_learning_rate_step=20000,
        final_learning_rate=1e-4,
        final_learning_rate_step=200000,
        update_frequency=32,  # 1
        minibatch_size=64,  # 128
        max_grad_norm=None,
        weight_decay=0,
        update_exploration=True,
        initial_exploration_rate=1,
        final_exploration_rate=0.05,  # 0.05
        final_exploration_step=800000,  # 40000
        adam_epsilon=1e-8,
        logging=False,
        loss="mse",
        save_network_frequency=400000,
        network_save_path=network_save_path,
        evaluate=True,
        test_envs=test_envs,
        test_episodes=n_tests,
        test_frequency=50000,  # 10000
        test_save_path=test_save_path,
        test_metric=TestMetric.MAX_CUT,
        seed=None)

    print("\n Created DQN agent with network:\n\n", agent.network)

    #############
    # TRAIN AGENT
    #############
    start = time.time()
    agent.learn(timesteps=nb_steps, verbose=True)
    print(time.time() - start)

    agent.save()

    ############
    # PLOT - learning curve
    ############
    with open(test_save_path, 'rb') as f:
        data = pickle.load(f)
    data = np.array(data)

    fig_fname = os.path.join(network_folder, "training_curve")

    plt.plot(data[:, 0], data[:, 1])
    plt.xlabel("Training run")
    plt.ylabel("Mean reward")
    if agent.test_metric == TestMetric.ENERGY_ERROR:
        plt.ylabel("Energy Error")
    elif agent.test_metric == TestMetric.BEST_ENERGY:
        plt.ylabel("Best Energy")
    elif agent.test_metric == TestMetric.CUMULATIVE_REWARD:
        plt.ylabel("Cumulative Reward")
    elif agent.test_metric == TestMetric.MAX_CUT:
        plt.ylabel("Max Cut")
    elif agent.test_metric == TestMetric.FINAL_CUT:
        plt.ylabel("Final Cut")

    plt.savefig(fig_fname + ".png", bbox_inches='tight')
    plt.savefig(fig_fname + ".pdf", bbox_inches='tight')

    plt.clf()

    ############
    # PLOT - losses
    ############
    with open(loss_save_path, 'rb') as f:
        data = pickle.load(f)
    data = np.array(data)

    fig_fname = os.path.join(network_folder, "loss")

    N = 50
    data_x = np.convolve(data[:, 0], np.ones((N, )) / N, mode='valid')
    data_y = np.convolve(data[:, 1], np.ones((N, )) / N, mode='valid')

    plt.plot(data_x, data_y)
    plt.xlabel("Timestep")
    plt.ylabel("Loss")

    plt.yscale("log")
    plt.grid(True)

    plt.savefig(fig_fname + ".png", bbox_inches='tight')
    plt.savefig(fig_fname + ".pdf", bbox_inches='tight')
Example #2
def run(save_loc="ER_40spin/eco",
        graph_save_loc="_graphs/validation/ER_40spin_p15_100graphs.pkl",
        batched=True,
        max_batch_size=None):

    print("\n----- Running {} -----\n".format(os.path.basename(__file__)))

    ####################################################
    # NETWORK LOCATION
    ####################################################
    # info_str = "train_mpnn"

    date = datetime.datetime.now().strftime("%Y-%m")
    data_folder = os.path.join(save_loc, 'data')
    network_folder = os.path.join(save_loc, 'network')

    mk_dir(data_folder)

    print("data folder :", data_folder)
    print("network folder :", network_folder)

    test_save_path = os.path.join(network_folder, 'test_scores.pkl')
    network_save_path = os.path.join(network_folder, 'network_best.pth')

    print("network params :", network_save_path)

    ####################################################
    # NETWORK SETUP
    ####################################################

    network_fn = MPNN
    network_args = {
        'n_layers': 3,
        'n_features': 64,
        'n_hid_readout': [],
        'tied_weights': False
    }

    ####################################################
    # SET UP ENVIRONMENT AND VARIABLES
    ####################################################

    gamma = 0.95
    step_factor = 2

    env_args = {
        'observables': DEFAULT_OBSERVABLES,
        'reward_signal': RewardSignal.BLS,
        'extra_action': ExtraAction.NONE,
        'optimisation_target': OptimisationTarget.CUT,
        'spin_basis': SpinBasis.BINARY,
        'norm_rewards': True,
        'memory_length': None,
        'horizon_length': None,
        'stag_punishment': None,
        'basin_reward': 1. / 40,
        'reversible_spins': True
    }
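    # These settings follow the ECO-DQN formulation: a richer observation set
    # (DEFAULT_OBSERVABLES), a reward signal that pays only for improving on
    # the best cut found so far (BLS), a small reward for reaching locally
    # optimal states (basin_reward, scaled as 1/n_spins for these 40-spin
    # graphs), and reversible spins so the agent can keep exploring.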

    ####################################################
    # LOAD VALIDATION GRAPHS
    ####################################################

    graphs_test = load_graph_set(graph_save_loc)

    ####################################################
    # SETUP NETWORK TO TEST
    ####################################################

    test_env = ising_env.make("SpinSystem",
                              SingleGraphGenerator(graphs_test[0]),
                              graphs_test[0].shape[0] * step_factor,
                              **env_args)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.device(device)
    print("Set torch default device to {}.".format(device))

    network = network_fn(n_obs_in=test_env.observation_space.shape[1],
                         **network_args).to(device)

    network.load_state_dict(torch.load(network_save_path, map_location=device))
    for param in network.parameters():
        param.requires_grad = False
    network.eval()

    print(
        "Sucessfully created agent with pre-trained MPNN.\nMPNN architecture\n\n{}"
        .format(repr(network)))

    ####################################################
    # TEST NETWORK ON VALIDATION GRAPHS
    ####################################################

    results, results_raw, history = test_network(network,
                                                 env_args,
                                                 graphs_test,
                                                 device,
                                                 step_factor,
                                                 return_raw=True,
                                                 return_history=True,
                                                 batched=batched,
                                                 max_batch_size=max_batch_size)

    results_fname = "results_" + os.path.splitext(
        os.path.split(graph_save_loc)[-1])[0] + ".pkl"
    results_raw_fname = "results_" + os.path.splitext(
        os.path.split(graph_save_loc)[-1])[0] + "_raw.pkl"
    history_fname = "results_" + os.path.splitext(
        os.path.split(graph_save_loc)[-1])[0] + "_history.pkl"

    for res, fname, label in zip(
        [results, results_raw, history],
        [results_fname, results_raw_fname, history_fname],
        ["results", "results_raw", "history"]):
        save_path = os.path.join(data_folder, fname)
        res.to_pickle(save_path)
        print("{} saved to {}".format(label, save_path))
Example #3
def run(save_loc="pretrained_agent/s2v",
        network_save_loc="experiments_new/pretrained_agent/networks/s2v/network_best_ER_200spin.pth",
        graph_save_loc="_graphs/benchmarks/ising_125spin_graphs.pkl",
        batched=True,
        max_batch_size=5):

    print("\n----- Running {} -----\n".format(os.path.basename(__file__)))

    ####################################################
    # FOLDER LOCATIONS
    ####################################################

    print("save location :", save_loc)
    print("network params :", network_save_loc)
    mk_dir(save_loc)

    ####################################################
    # NETWORK SETUP
    ####################################################

    network_fn = MPNN
    network_args = {
        'n_layers': 3,
        'n_features': 64,
        'n_hid_readout': [],
        'tied_weights': False
    }

    ####################################################
    # SET UP ENVIRONMENT AND VARIABLES
    ####################################################

    step_factor = 1

    env_args = {
        'observables': [Observable.SPIN_STATE],
        'reward_signal': RewardSignal.DENSE,
        'extra_action': ExtraAction.NONE,
        'optimisation_target': OptimisationTarget.CUT,
        'spin_basis': SpinBasis.BINARY,
        'norm_rewards': True,
        'memory_length': None,
        'horizon_length': None,
        'stag_punishment': None,
        'basin_reward': None,
        'reversible_spins': False
    }

    ####################################################
    # LOAD VALIDATION GRAPHS
    ####################################################

    graphs_test = load_graph_set(graph_save_loc)

    ####################################################
    # SETUP NETWORK TO TEST
    ####################################################

    test_env = ising_env.make("SpinSystem",
                              SingleGraphGenerator(graphs_test[0]),
                              graphs_test[0].shape[0] * step_factor,
                              **env_args)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.device(device)
    print("Set torch default device to {}.".format(device))

    network = network_fn(n_obs_in=test_env.observation_space.shape[1],
                         **network_args).to(device)

    network.load_state_dict(torch.load(network_save_loc, map_location=device))
    for param in network.parameters():
        param.requires_grad = False
    network.eval()

    print("Sucessfully created agent with pre-trained MPNN.\nMPNN architecture\n\n{}".format(repr(network)))

    ####################################################
    # TEST NETWORK ON VALIDATION GRAPHS
    ####################################################

    results, results_raw, history = test_network(network, env_args, graphs_test, device, step_factor,
                                                 return_raw=True, return_history=True, n_attempts=50,
                                                 batched=batched, max_batch_size=max_batch_size)

    results_fname = "results_" + os.path.splitext(os.path.split(graph_save_loc)[-1])[0] + ".pkl"
    results_raw_fname = "results_" + os.path.splitext(os.path.split(graph_save_loc)[-1])[0] + "_raw.pkl"
    history_fname = "results_" + os.path.splitext(os.path.split(graph_save_loc)[-1])[0] + "_history.pkl"

    for res, fname, label in zip([results, results_raw, history],
                                 [results_fname, results_raw_fname, history_fname],
                                 ["results", "results_raw", "history"]):
        save_path = os.path.join(save_loc, fname)
        res.to_pickle(save_path)
        print("{} saved to {}".format(label, save_path))
Example #4
def __test_network_batched(network,
                           env_args,
                           graphs_test,
                           device=None,
                           step_factor=1,
                           n_attempts=50,
                           return_raw=False,
                           return_history=False,
                           max_batch_size=None):

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # HELPER FUNCTION FOR NETWORK TESTING

    acting_in_reversible_spin_env = env_args['reversible_spins']

    if env_args['reversible_spins']:
        # If MDP is reversible, both actions are allowed.
        if env_args['spin_basis'] == SpinBasis.BINARY:
            allowed_action_state = (0, 1)
        elif env_args['spin_basis'] == SpinBasis.SIGNED:
            allowed_action_state = (1, -1)
    else:
        # If MDP is irreversible, only return the state of spins that haven't been flipped.
        if env_args['spin_basis'] == SpinBasis.BINARY:
            allowed_action_state = 0
        if env_args['spin_basis'] == SpinBasis.SIGNED:
            allowed_action_state = 1

    def predict(states):

        qs = network(states)

        if acting_in_reversible_spin_env:
            if qs.dim() == 1:
                actions = [qs.argmax().item()]
            else:
                actions = qs.argmax(1, True).squeeze(1).cpu().numpy()
            return actions
        else:
            if qs.dim() == 1:
                x = (states.squeeze()[:, 0] == allowed_action_state).nonzero()
                actions = [x[qs[x].argmax().item()].item()]
            else:
                disallowed_actions_mask = (states[:, :, 0] !=
                                           allowed_action_state)
                qs_allowed = qs.masked_fill(disallowed_actions_mask, -1000)
                actions = qs_allowed.argmax(1, True).squeeze(1).cpu().numpy()
            return actions

    # NETWORK TESTING

    results = []
    results_raw = []
    if return_history:
        history = []

    n_attempts = n_attempts if env_args["reversible_spins"] else 1
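    # Irreversible (S2V-style) agents are tested once per graph; reversible
    # agents get n_attempts randomly initialised episodes.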

    for j, test_graph in enumerate(graphs_test):

        i_comp = 0
        i_batch = 0
        t_total = 0

        n_spins = test_graph.shape[0]
        n_steps = int(n_spins * step_factor)

        test_env = ising_env.make("SpinSystem",
                                  SingleGraphGenerator(test_graph), n_steps,
                                  **env_args)

        print("Running greedy solver with +1 initialisation of spins...",
              end="...")
        # Calculate the greedy cut with all spins initialised to +1
        greedy_env = deepcopy(test_env)
        greedy_env.reset(spins=np.array([1] * test_graph.shape[0]))

        greedy_agent = Greedy(greedy_env)
        greedy_agent.solve()

        greedy_single_cut = greedy_env.get_best_cut()
        greedy_single_spins = greedy_env.best_spins

        print("done.")

        if return_history:
            actions_history = []
            rewards_history = []
            scores_history = []

        best_cuts = []
        init_spins = []
        best_spins = []

        greedy_cuts = []
        greedy_spins = []

        while i_comp < n_attempts:

            if max_batch_size is None:
                batch_size = n_attempts
            else:
                batch_size = min(n_attempts - i_comp, max_batch_size)

            i_comp_batch = 0

            if return_history:
                actions_history_batch = [[None] * batch_size]
                rewards_history_batch = [[None] * batch_size]
                scores_history_batch = []

            test_envs = [None] * batch_size
            best_cuts_batch = [-1e3] * batch_size
            init_spins_batch = [[] for _ in range(batch_size)]
            best_spins_batch = [[] for _ in range(batch_size)]

            greedy_envs = [None] * batch_size
            greedy_cuts_batch = []
            greedy_spins_batch = []

            obs_batch = [None] * batch_size

            print("Preparing batch of {} environments for graph {}.".format(
                batch_size, j),
                  end="...")

            for i in range(batch_size):
                env = deepcopy(test_env)
                obs_batch[i] = env.reset()
                test_envs[i] = env
                greedy_envs[i] = deepcopy(env)
                init_spins_batch[i] = env.best_spins
            if return_history:
                scores_history_batch.append(
                    [env.calculate_score() for env in test_envs])

            print("done.")

            # Calculate the max cut acting w.r.t. the network
            t_start = time.time()

            # pool = mp.Pool(processes=16)

            k = 0
            while i_comp_batch < batch_size:
                t1 = time.time()
                # Note: Do not convert list of np.arrays to FloatTensor, it is very slow!
                # see: https://github.com/pytorch/pytorch/issues/13918
                # Hence, here we convert a list of np arrays to a np array.
                obs_batch = torch.FloatTensor(np.array(obs_batch)).to(device)
                actions = predict(obs_batch)
                obs_batch = []

                if return_history:
                    scores = []
                    rewards = []

                i = 0
                for env, action in zip(test_envs, actions):

                    if env is not None:

                        obs, rew, done, info = env.step(action)

                        if return_history:
                            scores.append(env.calculate_score())
                            rewards.append(rew)

                        if not done:
                            obs_batch.append(obs)
                        else:
                            best_cuts_batch[i] = env.get_best_cut()
                            best_spins_batch[i] = env.best_spins
                            i_comp_batch += 1
                            i_comp += 1
                            test_envs[i] = None
                    i += 1
                    k += 1

                if return_history:
                    actions_history_batch.append(actions)
                    scores_history_batch.append(scores)
                    rewards_history_batch.append(rewards)

                # print("\t",
                #       "Par. steps :", k,
                #       "Env steps : {}/{}".format(k/batch_size,n_steps),
                #       'Time: {0:.3g}s'.format(time.time()-t1))

            t_total += (time.time() - t_start)
            i_batch += 1
            print("Finished agent testing batch {}.".format(i_batch))

            if env_args["reversible_spins"]:
                print(
                    "Running greedy solver with {} random initialisations of spins for batch {}..."
                    .format(batch_size, i_batch),
                    end="...")

                for env in greedy_envs:
                    Greedy(env).solve()
                    cut = env.get_best_cut()
                    greedy_cuts_batch.append(cut)
                    greedy_spins_batch.append(env.best_spins)

                print("done.")

            if return_history:
                actions_history += actions_history_batch
                rewards_history += rewards_history_batch
                scores_history += scores_history_batch

            best_cuts += best_cuts_batch
            init_spins += init_spins_batch
            best_spins += best_spins_batch

            if env_args["reversible_spins"]:
                greedy_cuts += greedy_cuts_batch
                greedy_spins += greedy_spins_batch

            # print("\tGraph {}, par. steps: {}, comp: {}/{}".format(j, k, i_comp, batch_size),
            #       end="\r" if n_spins<100 else "")

        i_best = np.argmax(best_cuts)
        best_cut = best_cuts[i_best]
        sol = best_spins[i_best]

        mean_cut = np.mean(best_cuts)

        if env_args["reversible_spins"]:
            idx_best_greedy = np.argmax(greedy_cuts)
            greedy_random_cut = greedy_cuts[idx_best_greedy]
            greedy_random_spins = greedy_spins[idx_best_greedy]
            greedy_random_mean_cut = np.mean(greedy_cuts)
        else:
            greedy_random_cut = greedy_single_cut
            greedy_random_spins = greedy_single_spins
            greedy_random_mean_cut = greedy_single_cut

        print(
            'Graph {}, best(mean) cut: {}({}), greedy cut (rand init / +1 init) : {} / {}.  ({} attempts in {}s)\t\t\t'
            .format(j, best_cut, mean_cut, greedy_random_cut,
                    greedy_single_cut, n_attempts, np.round(t_total, 2)))

        results.append([
            best_cut, sol, mean_cut, greedy_single_cut, greedy_single_spins,
            greedy_random_cut, greedy_random_spins, greedy_random_mean_cut,
            t_total / (n_attempts)
        ])

        results_raw.append(
            [init_spins, best_cuts, best_spins, greedy_cuts, greedy_spins])

        if return_history:
            history.append([
                np.array(actions_history).T.tolist(),
                np.array(scores_history).T.tolist(),
                np.array(rewards_history).T.tolist()
            ])

    results = pd.DataFrame(data=results,
                           columns=[
                               "cut", "sol", "mean cut",
                               "greedy (+1 init) cut", "greedy (+1 init) sol",
                               "greedy (rand init) cut",
                               "greedy (rand init) sol",
                               "greedy (rand init) mean cut", "time"
                           ])

    results_raw = pd.DataFrame(
        data=results_raw,
        columns=["init spins", "cuts", "sols", "greedy cuts", "greedy sols"])

    if return_history:
        history = pd.DataFrame(data=history,
                               columns=["actions", "scores", "rewards"])

    if not return_raw and not return_history:
        return results
    else:
        ret = [results]
        if return_raw:
            ret.append(results_raw)
        if return_history:
            ret.append(history)
        return ret
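
The irreversible-spin branch of predict masks out the Q-values of spins that have already been flipped before taking the argmax. A toy sketch of that masking in isolation (tensor values are illustrative):

import torch

qs = torch.tensor([[0.3, 0.9, 0.1],
                   [0.8, 0.2, 0.7]])       # [batch, n_spins] Q-values
spin_state = torch.tensor([[0., 1., 0.],
                           [1., 0., 0.]])  # first observable: current spin state
allowed_action_state = 0                   # BINARY basis: unflipped spins are 0

disallowed = spin_state != allowed_action_state  # True where a spin was already flipped
qs_allowed = qs.masked_fill(disallowed, -1000)   # large negative value never wins the argmax
actions = qs_allowed.argmax(1).cpu().numpy()     # -> array([0, 2])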
Example #5
def __test_network_sequential(network,
                              env_args,
                              graphs_test,
                              step_factor=1,
                              n_attempts=50,
                              return_raw=False,
                              return_history=False):

    if return_raw or return_history:
        raise NotImplementedError(
            "I've not got to this yet! Use the batched test script (it's faster anyway)."
        )

    results = []

    n_attempts = n_attempts if env_args["reversible_spins"] else 1

    for i, test_graph in enumerate(graphs_test):

        n_steps = int(test_graph.shape[0] * step_factor)

        best_cut = -1e3
        best_spins = []

        greedy_random_cut = -1e3
        greedy_random_spins = []

        greedy_single_cut = -1e3
        greedy_single_spins = []

        times = []

        test_env = ising_env.make("SpinSystem",
                                  SingleGraphGenerator(test_graph), n_steps,
                                  **env_args)
        net_agent = Network(network,
                            test_env,
                            record_cut=False,
                            record_rewards=False,
                            record_qs=False)

        greedy_env = deepcopy(test_env)
        greedy_env.reset(spins=np.array([1] * test_graph.shape[0]))
        greedy_agent = Greedy(greedy_env)

        greedy_agent.solve()

        greedy_single_cut = greedy_env.get_best_cut()
        greedy_single_spins = greedy_env.best_spins

        for k in range(n_attempts):

            net_agent.reset(clear_history=True)
            greedy_env = deepcopy(test_env)
            greedy_agent = Greedy(greedy_env)

            tstart = time.time()
            net_agent.solve()
            times.append(time.time() - tstart)

            cut = test_env.get_best_cut()
            if cut > best_cut:
                best_cut = cut
                best_spins = test_env.best_spins

            greedy_agent.solve()

            greedy_cut = greedy_env.get_best_cut()
            if greedy_cut > greedy_random_cut:
                greedy_random_cut = greedy_cut
                greedy_random_spins = greedy_env.best_spins

            print(
                '\nGraph {}, attempt : {}/{}, best cut : {}, greedy cut (rand init / +1 init) : {} / {}\t\t\t'
                .format(i + 1, k + 1, n_attempts, best_cut, greedy_random_cut,
                        greedy_single_cut),
                end=".")

        results.append([
            best_cut, best_spins, greedy_single_cut, greedy_single_spins,
            greedy_random_cut, greedy_random_spins,
            np.mean(times)
        ])

    return pd.DataFrame(data=results,
                        columns=[
                            "cut", "sol", "greedy (+1 init) cut",
                            "greedy (+1 init) sol", "greedy (rand init) cut",
                            "greedy (rand init) sol", "time"
                        ])
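
Both test routines reduce n_attempts episodes per graph to one reported solution by tracking a running maximum. The pattern in isolation, with solve_once as a hypothetical stand-in for a single network-guided episode:

import numpy as np

def best_of(solve_once, n_attempts=50):
    # Keep the best cut (and its spin configuration) across random restarts.
    best_cut, best_spins = -np.inf, None
    for _ in range(n_attempts):
        cut, spins = solve_once()
        if cut > best_cut:
            best_cut, best_spins = cut, spins
    return best_cut, best_spins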
Example #6
def run(save_loc="GBMLGG_100/s2v"):

    print("\n----- Running {} -----\n".format(os.path.basename(__file__)))

    ####################################################
    # SET UP ENVIRONMENT AND VARIABLES
    ####################################################

    gamma = 1
    step_fact = 1

    env_args = {
        'observables': [Observable.SPIN_STATE],
        'reward_signal': RewardSignal.DENSE,
        'extra_action': ExtraAction.NONE,
        'optimisation_target': OptimisationTarget.PVALUE,
        'spin_basis': SpinBasis.BINARY,
        'norm_rewards': False,
        'memory_length': None,
        'horizon_length': None,
        'stag_punishment': None,
        'basin_reward': None,
        'reversible_spins': False
    }

    ####################################################
    # SET UP TRAINING AND TEST GRAPHS
    ####################################################

    k = 4
    n_spins_train = k

    # train_graph_generator = RandomErdosRenyiGraphGenerator(n_spins=n_spins_train,p_connection=0.15,edge_type=EdgeType.DISCRETE)

    ####
    # Pre-generated test graphs
    ####
    # graph_save_loc = "/home2/wsdm/gyy/eco-dqn_v1/_graphs/testing/ER_200spin_p15_50graphs.pkl"
    # graphs_test = load_graph_set(graph_save_loc)
    # n_tests = len(graphs_test)
    n_tests = 2

    # test_graph_generator = SetGraphGenerator(graphs_test, ordered=True)

    ####################################################
    # SET UP TRAINING AND TEST ENVIRONMENTS
    ####################################################
    train_list = [('COADREAD', 100, 15), ('GBMLGG', 100, 6), ('STAD', 100, 15)]

    test_list = [
        'HNSC', 'ACC', 'LGG', 'KIPAN', 'UVM', 'CESC', 'BRCA', 'UCEC', 'OV',
        'DLBC', 'STAD', 'UCS', 'PRAD', 'CHOL', 'PAAD', 'TGCT', 'LUAD', 'STES',
        'GBMLGG', 'LIHC', 'BLCA', 'KIRC', 'KIRP', 'COAD', 'GBM', 'THCA',
        'READ', 'PCPG', 'COADREAD', 'LUSC', 'KICH', 'SARC'
    ]
    # test_list = [('GBMLGG', 400, 5), ('GBMLGG', 500, 5), ('GBMLGG', 600, 5), ('GBMLGG', 700, 5)]
    # test_list = ['STAD', 'GBMLGG', 'COADREAD']

    # mut_file_path = '/home2/wsdm/gyy/comet_v1/example_datasets/temp/{}_our_pnum={}.m2'
    mut_file_path = '/home2/wsdm/gyy/comet_v1/example_datasets/our/{}_our.m2'

    test_envs = [
        ising_env.make("SpinSystem",
                       mut_file_path.format(cancer_name),
                       int(n_spins_train * step_fact),
                       minFreq=5,
                       **env_args) for cancer_name in test_list
    ]
    '''
    test_envs = [ising_env.make("SpinSystem",
                                      mut_file_path.format(cancer_name, str(pnum)),
                                      int(n_spins_train*step_fact),
                                      minFreq=minfreq,
                                      **env_args) for cancer_name, pnum, minfreq in test_list]
    '''
    '''
    n_spins_test = train_graph_generator.get().shape[0]
    test_envs = [ising_env.make("SpinSystem",
                                mut_file_path,
                                int(n_spins_test*step_fact),
                                **env_args)]
    '''

    ####################################################
    # SET UP FOLDERS FOR SAVING DATA
    ####################################################

    data_folder = os.path.join(save_loc, 'data')
    network_folder = os.path.join(save_loc, 'network')

    mk_dir(data_folder)
    mk_dir(network_folder)
    # print(data_folder)
    network_save_path = os.path.join(network_folder, 'network.pth')
    test_save_path = os.path.join(network_folder, 'test_scores.pkl')
    loss_save_path = os.path.join(network_folder, 'losses.pkl')

    ####################################################
    # SET UP AGENT
    ####################################################

    nb_steps = 10000000

    network_fn = lambda: MPNN(n_obs_in_g=test_envs[0].observation_space.shape[1] + 1,
                              n_layers=2,
                              n_features=32,
                              n_hid_readout=[],
                              tied_weights=False)

    agent = DQN(
        test_envs,
        network_fn,
        init_network_params=None,
        init_weight_std=0.5,
        double_dqn=False,
        clip_Q_targets=True,
        replay_start_size=200,
        replay_buffer_size=3200,  # 20000
        gamma=gamma,  # 1
        update_target_frequency=10,  # 500
        update_learning_rate=True,
        initial_learning_rate=1e-2,
        peak_learning_rate=1e-2,
        peak_learning_rate_step=2000,
        final_learning_rate=1e-3,
        final_learning_rate_step=4000,
        update_frequency=4,  # 1
        minibatch_size=64,  # 128
        max_grad_norm=None,
        weight_decay=0,
        update_exploration=True,
        initial_exploration_rate=1,
        final_exploration_rate=0.1,  # 0.05
        final_exploration_step=10000,  # 40000
        adam_epsilon=1e-8,
        logging=False,
        loss="mse",
        save_network_frequency=4000,
        network_save_path=network_save_path,
        evaluate=True,
        test_envs=test_envs,
        test_episodes=n_tests,
        test_frequency=500,  # 10000
        test_save_path=test_save_path,
        test_metric=TestMetric.CUMULATIVE_REWARD,
        seed=None)

    print("\n Created DQN agent with network:\n\n", agent.network)

    #############
    # EVAL AGENT
    #############

    agent.load(
        '/home2/wsdm/gyy/eco-dqn_v2/experiments/GBMLGG_100/train/GBMLGG_100/s2v/network/network32000.pth'
    )
    agent.evaluate_agent()
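
A hedged invocation sketch; the checkpoint path hard-coded in run above is machine-specific and left as-is:

if __name__ == "__main__":
    # Load the pre-trained network and evaluate it on the listed cancer datasets.
    run(save_loc="GBMLGG_100/s2v")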