Example #1
        # Train a fresh agent here to avoid carry-over effects from
        # earlier training runs
        agent = w_QAgent(self.opt_env)
        agent.qlearn(3500)
        rews = utility(agent)

        # Keep the larger of the new utility and the best reward seen so far
        x = max(rews, self.max_reward)

        return (vector, x)


if __name__ == "__main__":
    # Command-line arguments: maximum tree depth, BFS iteration budget,
    # and the threshold read below as sys.argv[3]
    max_layer = int(sys.argv[1])
    num_iters = int(sys.argv[2])

    env = WindyGridworld()

    tree = Tree(env, max_layer)
    tree.threshold = float(sys.argv[3])
    tree.initialize()

    tree.BFS(num_iters)

    # Build the output path for the BFS results
    r_dir = os.path.abspath(os.pardir)
    data_dir = os.path.join(r_dir, "data-wgr")
    txt_dir = os.path.join(data_dir,
                           "bfs_result_{}.txt".format(tree.max_layer))

    a = tree.best_observed_choice()
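The excerpt ends before anything is written to txt_dir. A minimal sketch of how the result might be persisted, assuming a is a printable value returned by best_observed_choice() (the directory creation and plain-text format are assumptions, not part of the original):

    # Hypothetical continuation: write the best observed choice to disk.
    # os.makedirs and the str() serialization are assumptions.
    os.makedirs(data_dir, exist_ok=True)
    with open(txt_dir, "w") as f:
        f.write(str(a))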
Example #2
import gym  # pylint: disable=import-error
import time
import copy  # required by make_env below (was missing)
from collections import deque
import csv
import ast

from sklearn.model_selection import train_test_split  # pylint: disable=import-error
from wgrenv import WindyGridworld
from w_qlearn import w_QAgent
import matplotlib.pyplot as plt  # pylint: disable=import-error
from termcolor import colored  # pylint: disable=import-error
import multiprocessing as mp
from multiprocessing import Manager
import itertools

env = WindyGridworld()  # reference environment
input_data = []
output_data = []


def make_env(env, mod_seq):
    # Return a deep copy of env with the modifications in mod_seq applied.
    # Each element's first entry selects the modification type (0 adds a
    # jump cell, anything else a special cell); the next two entries are
    # the cell coordinates.
    ref_env = copy.deepcopy(env)
    for element in mod_seq:
        if element[0] == 0:
            ref_env.jump_cells.append((element[1], element[2]))
        else:
            ref_env.special.append((element[1], element[2]))

    return ref_env
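A short usage sketch: the modification sequence below reproduces the cells that Example #3 adds by hand (a jump cell at (4, 2) and special cells at (3, 1), (3, 4), (3, 5)), using the flag encoding that make_env reads:

# Usage sketch, at module level where env is defined above
mod_seq = [(0, 4, 2), (1, 3, 1), (1, 3, 4), (1, 3, 5)]
modified_env = make_env(env, mod_seq)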
Example #3
        if show:
            self.env.render()
        return [steps, total, states]

    def print_eval_result(self, output):
        # Print the [steps, total reward, states traversed] list
        # returned by evaluation
        print("Steps taken: {}".format(output[0]))
        print("Total reward: {}".format(output[1]))
        print("States traversed: {}".format(output[2]))


if __name__ == "__main__":
    env = WindyGridworld()

    # Hand-modify a copy of the reference environment: one jump cell
    # and three special cells
    modified = copy.deepcopy(env)
    modified.jump_cells.append((4, 2))
    modified.special.append((3, 1))
    modified.special.append((3, 4))
    modified.special.append((3, 5))

    # Time a 3000-iteration Q-learning run on the modified environment
    start = time.time()
    agent = w_QAgent(modified)
    agent.qlearn(3000, render=False)
    end = time.time()

    series = env.resettable_states()
    vals = []
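The excerpt cuts off here. A sketch of how the loop might continue, assuming the methods at the top of this example belong to w_QAgent and that the agent exposes an evaluation method, hypothetically named eval(state), returning the [steps, total, states] list shown there:

    # Hypothetical continuation: evaluate the trained agent from each
    # resettable start state. The eval() name and signature are
    # assumptions; print_eval_result is the method shown above.
    for state in series:
        output = agent.eval(state)
        vals.append(output[1])  # total reward
        agent.print_eval_result(output)

    print("Training took {:.2f} seconds".format(end - start))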