コード例 #1
0
ファイル: main_gui.py プロジェクト: fspinolo/GridWorld
    def setup_world(self):
        """
        Build out the Grid World and initialize agent.

        Adds a stats bar (status, last result, episode, winrate and average
        steps labels), reads the reward map from ``tetris_world.txt``, draws
        one disabled 40x40 button per map entry and finally creates
        ``self.agent`` from the loaded rewards.

        The map file holds one row of comma-separated integers per line and
        every value is negated on load. Loaded -1 cells are drawn black,
        loaded 1 cells show a green "$", everything else starts white.
        """
        self.cells = []

        # stats bar
        stats = Frame(self)
        stats.pack(side=TOP)
        self.status = Label(stats, text="Status: Idle")
        self.status.pack(side=LEFT)
        self.result = Label(stats, text="| Last Result: None")
        self.result.pack(side=LEFT)
        self.episode = Label(stats, text="| Episode #0")
        self.episode.pack(side=LEFT)
        self.winrate = Label(stats, text="| Winrate (last 100): N/A")
        self.winrate.pack(side=LEFT)
        self.avgsteps = Label(stats, text="| Avg Steps (last 100): N/A")
        self.avgsteps.pack(side=LEFT)

        # grid
        with open('tetris_world.txt', 'r') as f:
            rewards = [
                [-int(x) for x in line.split(',')]
                for line in f
                # skip blank lines (e.g. a trailing empty line), which would
                # otherwise make int('') raise ValueError
                if line.strip()
            ]
        # Size the grid from the map itself instead of assuming 22x22.
        for row_rewards in rewards:
            row_cells = []
            row_container = Frame(self)
            row_container.pack(side=TOP)
            for reward in row_rewards:
                cell_container = Frame(row_container, height=40, width=40)
                # keep the container at its configured 40x40 size instead of
                # letting the button inside resize it
                cell_container.pack_propagate(False)
                cell_container.pack(side=LEFT)
                cell = Button(cell_container, state=DISABLED)
                cell.pack(fill=BOTH, expand=1)
                if reward == -1:
                    cell.config(bg="black")
                elif reward == 1:
                    cell.config(
                        text="$", font=self.fonts['large'],
                        bg="green", fg="white"
                    )
                else:
                    cell.config(
                        bg="white", fg="black", font=self.fonts['small']
                    )
                row_cells.append(cell)
            self.cells.append(row_cells)

        # init agent
        self.agent = Agent(rewards)
コード例 #2
0
ファイル: main_gui.py プロジェクト: mikegwyn17/GridWorld
 def __init__(self, root):
     """
     Create the GUI frame inside *root*, build the world and attach an agent.

     :param root: parent window for the GUI
     """
     super().__init__(root)

     # Sit in cell (0, 0) of the parent's grid and let row/column 0 stretch.
     self.grid(column=0, row=0, sticky=(N, S, E, W))
     for configure in (self.columnconfigure, self.rowconfigure):
         configure(0, weight=1)

     # Arrow glyphs: up, down, left, right.
     self.arrows = ["\U00002191", "\U00002193", "\U00002190", "\U00002192"]

     point_sizes = (('small', 6), ('medium', 14), ('large', 24))
     self.fonts = {
         label: font.Font(family="Arial", size=points)
         for label, points in point_sizes
     }

     # setup_world() must leave self.rewards behind; the agent is built from it.
     self.setup_world()
     self.agent = Agent(self.rewards)
コード例 #3
0
ファイル: stats.py プロジェクト: fspinolo/GridWorld
"""Run the agent without the GUI and tabulate results in CSV format"""

from sarsa_lambda import Agent

with open('tetris_world.txt', 'r') as f:
    # One row of comma-separated integers per line; every value is negated on
    # load. Blank lines (e.g. a trailing empty line) are skipped rather than
    # crashing on int('').
    rewards = [
        [-int(x) for x in line.split(',')]
        for line in f
        if line.strip()
    ]

agent = Agent(rewards)
# Values swept in turn (each is assigned to agent.gamma below).
tenths = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
episodes = 10000  # episodes run for each swept value
interval = 1000   # spacing, in episodes, between CSV checkpoint columns
# CSV header row: an empty first cell, then the checkpoint episode numbers
# 0, interval, 2*interval, ..., episodes.
top_label = (
    ',' + ','.join(str(x) for x in range(0, episodes + 1, interval)) + '\n'
)

# Write winrate.csv and avgsteps.csv side by side; both start with the same
# two header rows ('Gamma' and the checkpoint labels).
with open('winrate.csv', 'w') as w, open('avgsteps.csv', 'w') as a:
    print('Testing Gammas...')
    w.write('Gamma\n')
    a.write('Gamma\n')
    w.write(top_label)
    a.write(top_label)
    # One CSV row per gamma: it starts with the gamma value and a 0 under the
    # episode-0 column of the header.
    # NOTE(review): the same agent object is reused for every gamma and no
    # reset is visible in this chunk, so learning may carry over between
    # gamma values; confirm the (truncated) remainder handles this.
    for gamma in tenths:
        agent.gamma = gamma
        wstr = astr = '{},0'.format(gamma)
        for i in range(episodes):
            agent.episode()
            # Checkpoint every `interval` episodes. agent.episodes is presumably
            # the agent's cumulative episode counter; confirm in sarsa_lambda.
            if agent.episodes > 1 and not agent.episodes % interval:
コード例 #4
0
ファイル: main_gui.py プロジェクト: mikegwyn17/GridWorld
class GridWorldGUI(Frame):
    """Tkinter front-end that draws the grid world and animates the agent."""

    def __init__(self, root):
        """
        Build the GUI inside *root* and attach a learning agent.

        :param root: parent window for the GUI
        """
        super().__init__(root)
        self.grid(column=0, row=0, sticky=(N, S, E, W))
        self.columnconfigure(0, weight=1)
        self.rowconfigure(0, weight=1)
        # Arrow glyphs (up, down, left, right), indexed by agent.best_action().
        self.arrows = ["\U00002191", "\U00002193", "\U00002190", "\U00002192"]
        self.fonts = {
            'small': font.Font(family="Arial", size=6),
            'medium': font.Font(family="Arial", size=14),
            'large': font.Font(family="Arial", size=24)
        }
        self.setup_world()
        self.agent = Agent(self.rewards)

    def setup_world(self):
        """
        Load ``world_map.txt`` into ``self.rewards`` and draw one cell per entry.

        Each line of the file is a row of comma-separated integers, and every
        value is negated on load. Cells whose loaded value is -1 are drawn
        black, cells equal to 1 show a green "$", and all others start white.
        """
        self.cells = []
        with open('world_map.txt', 'r') as f:
            self.rewards = [
                [-int(x) for x in line.split(',')]
                for line in f
                # blank lines (e.g. a trailing empty line) would make
                # int('') raise ValueError
                if line.strip()
            ]
        # Size the grid from the map itself instead of assuming 22x22.
        for row_rewards in self.rewards:
            row_cells = []
            row_container = Frame(self)
            row_container.pack(side=TOP)
            for reward in row_rewards:
                cell_container = Frame(row_container, height=40, width=40)
                # keep the container at its configured 40x40 size
                cell_container.pack_propagate(False)
                cell_container.pack(side=LEFT)
                cell = Button(cell_container, state=DISABLED)
                cell.pack(fill=BOTH, expand=1)
                if reward == -1:
                    cell.config(bg="black")
                elif reward == 1:
                    cell.config(text="$", font=self.fonts['large'], bg="green", fg="white")
                else:
                    cell.config(bg="white", fg="black", font=self.fonts['small'])
                row_cells.append(cell)
            self.cells.append(row_cells)

    def update_grid(self):
        """
        Redraw every interior zero-reward cell, then refresh the window.

        The outermost ring of cells is never touched. The agent's current
        cell shows a skull-and-crossbones on blue. Every other zero-reward
        cell shows the agent's best action as an arrow, with a font sized by
        agent.best_value(): < 0.25 small, < 0.75 medium, otherwise large.
        """
        for i in range(1, len(self.rewards) - 1):
            for j in range(1, len(self.rewards[i]) - 1):
                if self.rewards[i][j]:
                    continue  # non-zero (terminal) cells keep their static look
                if (i, j) == (self.agent.row, self.agent.col):
                    self.cells[i][j].config(text="\U00002620", bg="blue", fg="white", font=self.fonts['large'])
                    continue
                arrow = self.arrows[self.agent.best_action(i, j)]
                best = self.agent.best_value(i, j)
                if best < 0.25:
                    afont = self.fonts['small']
                elif best < 0.75:
                    afont = self.fonts['medium']
                else:
                    afont = self.fonts['large']
                self.cells[i][j].config(text=arrow, bg="white", fg="black", font=afont)
        self.update()

    def episode(self):
        """
        Animate one episode until the agent lands on a non-zero-reward cell.

        Prints whether the episode ended on the goal (reward 1) or a wall,
        then decays the agent's epsilon.
        """
        self.agent.spawn()
        self.update_grid()
        while not self.rewards[self.agent.row][self.agent.col]:
            self.agent.take_step()
            r = self.rewards[self.agent.row][self.agent.col]
            if r:
                last = 'GOAL' if r == 1 else 'WALL'
                print('Terminating on {}'.format(last))
            self.update_grid()
        # Linear decay, floored at 0 so epsilon never turns into a negative
        # "probability" after many episodes.
        self.agent.epsilon = max(0.0, self.agent.epsilon - 0.001)

    def run(self, episodes):
        """
        Animate *episodes* consecutive episodes.

        :param episodes: number of episodes to run
        """
        for _ in range(episodes):
            self.episode()
コード例 #5
0
ファイル: main_gui.py プロジェクト: fspinolo/GridWorld
class GridWorldGUI(Frame):
    """GUI for sarsa(lambda) agent"""

    def __init__(self, root):
        """
        Constructor: pack the frame into *root*, set up fonts and arrow
        glyphs, then build the world (which also creates the agent).
        :param root: parent window for the GUI
        """
        super().__init__(root)
        self.pack(fill=BOTH, expand=1)
        self.columnconfigure(0, weight=1)
        self.rowconfigure(0, weight=1)
        # Arrow glyphs (up, down, left, right), indexed by agent.best_action().
        self.arrows = ["\U00002191", "\U00002193", "\U00002190", "\U00002192"]
        self.fonts = {
            'small': font.Font(family="Arial", size=12),
            'medium': font.Font(family="Arial", size=18),
            'large': font.Font(family="Arial", size=28)
        }
        self.setup_world()

    def setup_world(self):
        """
        Build out the Grid World and initialize agent.

        Adds a stats bar, reads the reward map from ``tetris_world.txt``
        (one row of comma-separated integers per line, each value negated on
        load), draws one disabled 40x40 button per map entry and creates
        ``self.agent`` from the loaded rewards. Loaded -1 cells are drawn
        black, loaded 1 cells show a green "$", everything else starts white.
        """
        self.cells = []

        # stats bar
        stats = Frame(self)
        stats.pack(side=TOP)
        self.status = Label(stats, text="Status: Idle")
        self.status.pack(side=LEFT)
        self.result = Label(stats, text="| Last Result: None")
        self.result.pack(side=LEFT)
        self.episode = Label(stats, text="| Episode #0")
        self.episode.pack(side=LEFT)
        self.winrate = Label(stats, text="| Winrate (last 100): N/A")
        self.winrate.pack(side=LEFT)
        self.avgsteps = Label(stats, text="| Avg Steps (last 100): N/A")
        self.avgsteps.pack(side=LEFT)

        # grid
        with open('tetris_world.txt', 'r') as f:
            rewards = [
                [-int(x) for x in line.split(',')]
                for line in f
                # skip blank lines (e.g. a trailing empty line), which would
                # otherwise make int('') raise ValueError
                if line.strip()
            ]
        # Size the grid from the map itself instead of assuming 22x22.
        for row_rewards in rewards:
            row_cells = []
            row_container = Frame(self)
            row_container.pack(side=TOP)
            for reward in row_rewards:
                cell_container = Frame(row_container, height=40, width=40)
                # keep the container at its configured 40x40 size
                cell_container.pack_propagate(False)
                cell_container.pack(side=LEFT)
                cell = Button(cell_container, state=DISABLED)
                cell.pack(fill=BOTH, expand=1)
                if reward == -1:
                    cell.config(bg="black")
                elif reward == 1:
                    cell.config(
                        text="$", font=self.fonts['large'],
                        bg="green", fg="white"
                    )
                else:
                    cell.config(
                        bg="white", fg="black", font=self.fonts['small']
                    )
                row_cells.append(cell)
            self.cells.append(row_cells)

        # init agent
        self.agent = Agent(rewards)

    def update_grid(self):
        """
        Update the grid's arrows and agent position.

        Only interior cells (the outermost ring is skipped) with a zero
        reward are redrawn. The agent's cell gets a skull-and-crossbones on
        blue; every other cell gets the agent's best-action arrow, with a
        font sized by agent.confidence(): < 0.05 small, < 0.2 medium,
        otherwise large.
        """
        rewards = self.agent.rewards
        for i in range(1, len(rewards) - 1):
            for j in range(1, len(rewards[i]) - 1):
                if rewards[i][j]:
                    continue  # non-zero (terminal) cells keep their static look
                if (i, j) == (self.agent.row, self.agent.col):
                    self.cells[i][j].config(
                        text="\U00002620", bg="blue", fg="white",
                        font=self.fonts['large']
                    )
                    continue
                arrow = self.arrows[self.agent.best_action(i, j)]
                confidence = self.agent.confidence(i, j)
                if confidence < 0.05:
                    afont = self.fonts['small']
                elif confidence < 0.2:
                    afont = self.fonts['medium']
                else:
                    afont = self.fonts['large']
                self.cells[i][j].config(
                    text=arrow, bg="white", fg="black", font=afont
                )
        self.update()

    def update_stats(self, status):
        """
        Update the stats bar at the top of the GUI.
        :param status: current agent status
        """
        agent = self.agent
        self.status.config(text="Status: {}".format(status))
        # Keep the "| " separator the labels were created with in setup_world.
        self.episode.config(text="| Episode #{}".format(agent.episodes))
        if agent.episodes > 1:
            # presumably the agent still sits on the terminal cell that ended
            # its last episode
            result = (
                'GOAL' if agent.rewards[agent.row][agent.col] == 1 else 'WALL'
            )
            self.result.config(
                text="| Last Result: {} ({} steps)".format(result, agent.steps)
            )
            # Guard the averages: an empty history would divide by zero.
            if agent.last100:
                winrate = sum(agent.goal) / len(agent.last100)
                avgsteps = sum(agent.last100) / len(agent.last100)
                self.winrate.config(
                    text="| Winrate (last 100): {}".format(winrate)
                )
                self.avgsteps.config(
                    text="| Avg Steps (last 100): {}".format(avgsteps)
                )
        self.update()

    def run(self, episodes):
        """
        Show the agent navigating Grid World.
        :param episodes: the number of episodes to run
        """
        self.update_stats("Running")
        for _ in range(episodes):
            self.agent.spawn()
            # show the spawn position before the first step
            self.update_grid()
            while not self.agent.rewards[self.agent.row][self.agent.col]:
                self.agent.take_step()
                self.update_grid()
            self.update_stats("Running")
        self.update_stats("Idle")

    def train(self, episodes):
        """
        Train the agent without showing movement, update arrows every 100 episodes.
        :param episodes: number of episodes to train
        """
        self.update_stats('Training')
        for i in range(episodes):
            self.agent.episode()
            if i % 100 == 0:
                self.update_stats('Training')
                self.update_grid()
        # One final refresh so the arrows reflect the last episodes, which
        # the every-100 refresh above may have skipped.
        self.update_grid()
        self.update_stats('Idle')

    def reset_agent(self):
        """Reset the agent and grid arrows."""
        self.agent.reset()
        self.update_grid()