    def __init__(self, alphaStepSize=1e-4, **kwargs):

        SARSA.__init__(self, alphaStepSize=alphaStepSize, **kwargs)

        self.numFeatures = self.numRays + 1
        self.initializeWeights()
        self.resetElibilityTraces()

        grid = np.arange(0, self.numRays)
        self.plotGrid = grid - np.mean(grid)
Example #2
def evaluate(env, config, q_table, episode, render=False, output=True):
    """
    Evaluate configuration of SARSA on given environment initialised with given Q-table

    :param env (gym.Env): environment to execute evaluation on
    :param config (Dict[str, float]): configuration dictionary containing hyperparameters
    :param q_table (Dict[(Obs, Act), float]): Q-table mapping observation-action to Q-values
    :param episode (int): episodes of training completed
    :param render (bool): flag whether evaluation runs should be rendered
    :param output (bool): flag whether mean evaluation performance should be printed
    :return (float, float): mean and standard deviation of reward received over episodes
    """
    eval_agent = SARSA(
        num_acts=env.action_space.n,
        gamma=config["gamma"],
        epsilon=0.0,
        alpha=config["alpha"],
    )
    eval_agent.q_table = q_table
    episodic_rewards = []
    for eps_num in range(config["eval_episodes"]):
        obs = env.reset()
        if render:
            env.render()
            sleep(1)
        episodic_reward = 0
        done = False

        steps = 0
        while not done and steps <= config["max_episode_steps"]:
            steps += 1
            act = eval_agent.act(obs)
            n_obs, reward, done, info = env.step(act)
            if render:
                env.render()
                sleep(1)

            episodic_reward += reward

            obs = n_obs

        episodic_rewards.append(episodic_reward)

    mean_reward = np.mean(episodic_rewards)
    std_reward = np.std(episodic_rewards)

    if output:
        print(
            f"EVALUATION ({episode}/{CONFIG['total_eps']}): MEAN REWARD OF {mean_reward}"
        )
        if mean_reward >= 0.9:
            print(f"EVALUATION: SOLVED")
        else:
            print(f"EVALUATION: NOT SOLVED!")
    return mean_reward, std_reward
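
The SARSA class that evaluate (and the train functions further down) relies on is not shown on this page. Below is a minimal tabular sketch consistent with how it is constructed and called here (num_acts/gamma/epsilon/alpha constructor, a q_table attribute, and act/learn/schedule_hyperparameters methods); the internals, including the defaultdict Q-table and the epsilon-greedy tie-breaking, are assumptions.

import random
from collections import defaultdict


class TabularSARSA:
    """Hedged sketch of the SARSA interface used by evaluate()/train() on this page."""

    def __init__(self, num_acts, gamma, epsilon, alpha):
        self.num_acts = num_acts
        self.gamma = gamma
        self.epsilon = epsilon
        self.alpha = alpha
        self.q_table = defaultdict(float)  # maps (obs, act) -> Q-value, default 0.0

    def act(self, obs):
        # Epsilon-greedy action selection over the tabular Q-values
        if random.random() < self.epsilon:
            return random.randrange(self.num_acts)
        return max(range(self.num_acts), key=lambda a: self.q_table[(obs, a)])

    def learn(self, obs, act, reward, n_obs, n_act, done):
        # On-policy SARSA backup: bootstrap from the action actually taken next
        target = reward if done else reward + self.gamma * self.q_table[(n_obs, n_act)]
        self.q_table[(obs, act)] += self.alpha * (target - self.q_table[(obs, act)])

    def schedule_hyperparameters(self, timestep, max_timestep):
        # Left as a no-op here; see the decay sketch near the train() functions below
        pass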
    def __init__(self, numInnerBins=4, numOuterBins=4, binCutoff=0.5, alphaStepSize=0.2,
                 useQLearningUpdate=False, **kwargs):

        SARSA.__init__(self, alphaStepSize=alphaStepSize, **kwargs)

        self.numInnerBins = numInnerBins
        self.numOuterBins = numOuterBins
        self.numBins = numInnerBins + numOuterBins
        self.binCutoff = binCutoff
        self.useQLearningUpdate = useQLearningUpdate
        self.initializeQValues()
        self.initializeBinData()
        self.resetElibilityTraces()
        self.eligibilityTraceThreshold = 0.1
    def __init__(self, numInnerBins=4, numOuterBins=4, binCutoff=0.5, alphaStepSize=0.2, forceDriveStraight=False,
                 useQLearningUpdate=False, **kwargs):

        SARSA.__init__(self, alphaStepSize=alphaStepSize, **kwargs)

        self.numInnerBins = numInnerBins
        self.numOuterBins = numOuterBins
        self.numBins = numInnerBins + numOuterBins
        self.binCutoff = binCutoff
        self.useQLearningUpdate = useQLearningUpdate
        self.forceDriveStraight = forceDriveStraight
        self.initializeQValues()
        self.initializeBinData()
        self.resetElibilityTraces()
        self.eligibilityTraceThreshold = 0.1
Example #6
def main():
    nProcess = multi.cpu_count()
    name = ["MCSteps", "SSteps", "Qsteps"]# Change this!
    algQs=list()
    algrews=list()
    steps = [250000] * 4
    for i, agent in enumerate([MonteCarlo(env_in=EZ21()), SARSA(env_in=EZ21()), QLearn(env_in=EZ21())]):
        print(name[i])
        Qs = list()
        algrews.append([])
        run_sum = 0
        for eps in steps:
            run_sum += eps
            print(run_sum)
            agent.n = eps
            agent.iter_opt()
            Qs.append(deepcopy(agent.Q))
        with multi.Pool(nProcess) as pool:
            algrews.append(pool.map(play_rounds, algQs[i]))

        with open(name[i] + "_algQs", 'wb') as myfile:
            pickle.dump(algQs, myfile)
        with open(name[i] + "_algrews", 'wb') as myfile:
            pickle.dump(algrews, myfile)
    return
Example #7
	def __init__(self):
		self.ball = Ball()
		self.paddle = Paddle()
		self.agent = QLearning(10, 0.7, 0.05)
		self.sarsa_agent = SARSA(10, 0.7, 0.05)
		self.state = (self.ball.x, self.ball.y, self.ball.velocity_x, self.ball.velocity_y, self.paddle.y)
		self.score = 0
		self.reward = 0
		self.game_number = 0
		self.scores = []
		self.finished_training = False
		self.finished_testing = False
		self.is_active = True
		self.previous_state = None
		self.previous_action = None	
		self.training_stats = []
		self.test_stats = []
def evaluate(env, config, q_table, render=False):
    """
    Evaluate configuration of SARSA on given environment initialised with given Q-table

    :param env (gym.Env): environment to execute evaluation on
    :param config (Dict[str, float]): configuration dictionary containing hyperparameters
    :param q_table (Dict[(Obs, Act), float]): Q-table mapping observation-action to Q-values
    :param render (bool): flag whether evaluation runs should be rendered
    :return (float, float, int): mean and standard deviation of return received over episodes, number
        of negative returns
    """
    eval_agent = SARSA(
        num_acts=env.action_space.n,
        gamma=config["gamma"],
        epsilon=0.0,
        alpha=config["alpha"],
    )
    eval_agent.q_table = q_table
    episodic_returns = []
    for eps_num in range(config["eval_episodes"]):
        obs = env.reset()
        if render:
            env.render()
            sleep(1)
        episodic_return = 0
        done = False

        steps = 0
        while not done and steps <= config["max_episode_steps"]:
            steps += 1
            act = eval_agent.act(obs)
            n_obs, reward, done, info = env.step(act)
            if render:
                env.render()
                sleep(1)

            episodic_return += reward

            obs = n_obs

        episodic_returns.append(episodic_return)

    mean_return = np.mean(episodic_returns)
    std_return = np.std(episodic_returns)
    negative_returns = sum([ret < 0 for ret in episodic_returns])
    return mean_return, std_return, negative_returns
Example #9
File: main.py Project: arame/Sarsa_Lab4
def main():
    total_episodes = 500
    N = 20
    gamma = 0.9
    epsilon = 0.999
    decay = 0.99
    alpha = 0.5
    _dungeon = Dungeon(N)
    no_actions = 4
    no_states = N * N
    q_values = np.zeros((no_states, no_actions))
    state_position_dict = {
        i * N + j: (i, j)
        for i in range(N) for j in range(N)
    }
    position_state_dict = {v: k for k, v in state_position_dict.items()}
    no_of_steps = []
    epsilons = []
    for _ep in range(total_episodes):
        no_steps = 0
        _ = _dungeon.reset()
        sarsa = SARSA(_dungeon, epsilon, decay, alpha, gamma, q_values)
        position_agent = _dungeon.position_agent
        s_current = position_state_dict[position_agent[0], position_agent[1]]
        position_exit = _dungeon.position_exit
        s_exit = position_state_dict[position_exit[0], position_exit[1]]
        if s_current == s_exit:
            continue
        a_next = sarsa(
            s_current, sarsa.q_values
        )  # gets the action from the policy (see __call__ method)
        while s_current != s_exit:
            no_steps += 1
            _, r_next, _ = _dungeon.step(a_next)
            position_agent = _dungeon.position_agent
            s_next = position_state_dict[position_agent[0], position_agent[1]]
            a_next_next = sarsa(
                s_next, sarsa.q_values
            )  # gets the action from the policy (see __call__ method)
            sarsa.update_values(s_current, a_next, r_next, s_next, a_next_next)
            s_current = s_next
            a_next = a_next_next
        q_values = sarsa.q_values
        print("For episode " + str(_ep + 1) + " the number of steps were " +
              str(no_steps) + " with epsilon " + str(sarsa.epsilon))
        no_of_steps.append(no_steps)
        epsilons.append(sarsa.epsilon)
        epsilon = sarsa.update_epsilon()
    print("Cell    Q values")
    for i in range(N * N):
        print(i + 1, "   ", q_values[i])
    sarsa.display_values(no_of_steps, epsilons)
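
The update_values method called above is not included in this snippet. Assuming it performs the standard on-policy SARSA backup on the (N*N, 4) q_values array, with the alpha and gamma defined at the top of main, it reduces to something like the following sketch.

import numpy as np


def sarsa_update(q_values, s, a, r, s_next, a_next, alpha=0.5, gamma=0.9):
    """One SARSA backup on a (num_states, num_actions) Q-array (hedged sketch)."""
    td_target = r + gamma * q_values[s_next, a_next]
    q_values[s, a] += alpha * (td_target - q_values[s, a])
    return q_values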
Example #10
def main():
    print('New agent online!')
    print('..... Initializing learning algorithm: SARSA')
    ACTIONS = [MOVE, SHOOT, PASS_CLOSE, PASS_FAR, DRIBBLE]
    sarsa_agent = SARSA(ACTIONS)  # keep the instance from shadowing the SARSA class
    print('..... Initializing discretization with CMAC')
    cmac = CMAC(1, 0.5, 0.1)  # keep the instance from shadowing the CMAC class
    print('..... Loading HFO environment')
    hfo = HFOEnvironment()
    print('..... Connecting to HFO server')
    hfo.connectToServer(HIGH_LEVEL_FEATURE_SET,
                      'bin/teams/base/config/formations-dt', 6000,
                      'localhost', 'base_left', False)
    print('..... Start training')
    for episode in itertools.count():
        print('..... Starting episode %d' % episode)
        status = IN_GAME
        step = 0
        while status == IN_GAME:
            step += 1
            old_status = status
            # Get the vector of state features for the current state
            features = hfo.getState()
            state = transformFeatures(features)
            print('State: %s' % str(state))

            action = select_action(state)
            hfo.act(action)
            #print('Action: %s' % str(action))
            # Advance the environment and get the game status
            status = hfo.step()
            #print('Status: %s' % str(status))
            print('.......... Step %d: %s - %s - %s' % (step, str(old_status), str(action), str(status)))
        # Check the outcome of the episode
        print('..... Episode ended with %s'% hfo.statusToString(status))
        # Quit if the server goes down
        if status == SERVER_DOWN:
            hfo.act(QUIT)
            break
Example #11
    notrl_tot_steps = 0
    notrl_returns = []
    notrl_steps = []

    # create grid-world instance
    canyon = True
    grid = GridWorld(4, canyon)
    grid.make_maps()

    possible_actions = grid.possible_actions
    world = grid.world
    grid.list_of_maps.reverse()

    # Direct learning on final grid
    print("Direct learning on final grid")
    sarsa = SARSA(grid.final_grid, possible_actions, world)
    Q, returns, episodes, steps = do_task(sarsa, grid,
                                          len(grid.list_of_maps) - 1)
    notrl_returns.append(returns)
    notrl_steps.append(steps)
    notrl_tot_steps += steps[-1]
    print("-" * 80)

    # Incremental transfer learning
    if canyon:
        canyon_str = "(CANYON)"
    else:
        canyon_str = "(NO CANYON)"
    print("Incremental transfer learning", canyon_str)
    Q = None
    for task, current_map in enumerate(grid.list_of_maps):
Example #12
 if (arg == "default"):
     wG = WindyGridworld()
     graphTitle = "SARSA for regular Windy Gridworld"
 elif (arg == "kings"):
     wG = WindyGridworldK()
     graphTitle = "SARSA for Windy Gridworld with King's moves"
 elif (arg == "stochastic"):
     wG = WindyGridworldS()
     graphTitle = "SARSA for Windy Gridworld with Stochasticity"
 else:
     print("Incorrect argument")
     exit()
 numStates = wG.getNumStates()
 numActions = wG.getNumActions()
 discount = wG.getDiscount()
 transitions = wG.getTransition()
 start = wG.getStartState()
 end = wG.getEndState()
 numEpisodes = 200
 yMean = np.zeros((numEpisodes, ))
 seedvals = [30, 46, 73, 92, 29, 65, 8, 50, 11, 81]
 for seedval in seedvals:
     x, y = SARSA(seedval, transitions, numStates, numActions, discount,
                  start, end, numEpisodes)
     yMean += y
 yMean /= len(seedvals)
 plt.plot(yMean, x)
 plt.xlabel("Time Steps")
 plt.ylabel("Episodes")
 plt.title(graphTitle)
 plt.show()
Example #13
                        choices=['qlearn', 'sarsa', 'esarsa'],
                        default='qlearn')


parser = ArgumentParser()
algo_args(parser)
args = parser.parse_args()

for eps in [0.05, 0.2]:
    fname = '{}-{}.csv'.format(args.algo, eps)
    fpath = os.path.join("exps", fname)
    with open(fpath, "w+") as fp:
        total_rewards = 0
        options = {
            "qlearn": lambda: QLearn(eps=eps),
            "sarsa": lambda: SARSA(eps=eps),
            "esarsa": lambda: ExpectedSARSA(eps=eps)
        }
        algo = options.get(args.algo)()

        for episode in range(10000):
            grid = GridWorld()
            agent = Agent()
            s = agent.position()
            actions = [(0, 1), (1, 0), (-1, 0), (0, -1)]
            a = random.choice(actions)
            episode_reward = 0

            def step(s, a):
                s_ = grid.move(s, a)
                agent.position(s_)
Example #14
File: boxy.py Project: Balint-H/BBB
import pickle  # needed for the pickle.load calls below
import pandas as pd
from decimal import Decimal
from sarsa import SARSA
from q_learn import QLearn
from matplotlib.lines import Line2D

with open(r'C:\Source_files\Python\Pantry\MCQs', "rb") as f:
    MQs = pickle.load(f)
with open(r'C:\Source_files\Python\Pantry\SARSA_Qs', "rb") as f:
    SARSAQs = pickle.load(f)
with open(r'C:\Source_files\Python\BBB\QLong_Qs', "rb") as f:
    QLearnQs = pickle.load(f)

Qs = zip(MQs, SARSAQs, QLearnQs)
MC = MonteCarlo(env_in=EZ21())
SRS = SARSA(env_in=EZ21())
QL = QLearn(env_in=EZ21())
MCrews = list()
SARSArews = list()
QRews = list()
for j, (M, S, L) in enumerate(Qs):
    if j >= 50:
        break
    print(j)
    MCrews.append([])
    SARSArews.append([])
    QRews.append([])
    MC.reset(Q_in=M)
    SRS.reset(Q_in=S)
    QL.reset(Q_in=L)
    for i in range(50000):
Example #15
    # create grid-world instance
    grid = GridWorld()
    grid.make_maps()

    # Change index to get different maps 0-4
    current_map = grid.list_of_maps[0]

    if not current_map:
        print("Map index is out of range.")
        sys.exit()

    possible_actions = grid.possible_actions
    x_lim, y_lim = grid.x_lim, grid.y_lim

    # creates SARSA instance
    sarsa = SARSA(current_map, possible_actions, x_lim, y_lim)

    # initialize algorithm parameters
    old_mean = 0
    delta = 0.000001
    steps = 0

    state = grid.reset_state()
    print("Started at ", state)
    for episode in range(nEp):
        action = sarsa.epsilon_greedy_random_action(state)
        for step in itertools.count():
            new_state, reward = sarsa.take_step(state, action)
            new_action = sarsa.epsilon_greedy_random_action(new_state)
            sarsa.update_Q(state, action, new_state, new_action, reward)
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 17 02:52:37 2019

@author: thoma
"""

from sarsa import SARSA
import gym
import matplotlib.pyplot as plt
import numpy as np

write_path = '../../data/data_long_sarsa.txt'

T = 1000
nb_episodes = 500

env = gym.make('MountainCar-v0')
agent = SARSA(env, T)

lengths = -np.asarray(agent.learn(nb_episodes))
agent.generate_trajectory_file(200, write_path)

plt.plot(
    np.arange(len(lengths))[::5],
    np.convolve(lengths, np.ones(5, ) / 5, mode='same')[::5])
plt.show()
Example #17
class Game:
	def __init__(self):
		self.ball = Ball()
		self.paddle = Paddle()
		self.agent = QLearning(10, 0.7, 0.05)
		self.sarsa_agent = SARSA(10, 0.7, 0.05)
		self.state = (self.ball.x, self.ball.y, self.ball.velocity_x, self.ball.velocity_y, self.paddle.y)
		self.score = 0
		self.reward = 0
		self.game_number = 0
		self.scores = []
		self.finished_training = False
		self.finished_testing = False
		self.is_active = True
		self.previous_state = None
		self.previous_action = None	
		self.training_stats = []
		self.test_stats = []

	def discretize_state(self):
		if self.is_active == False:
			return (-1,-1,-1,-1,-1)

		if self.ball.velocity_x > 0:
			discrete_velocity_x = 1
		else:
			discrete_velocity_x = -1

		if self.ball.velocity_y >= 0.015:
			discrete_velocity_y = 1
		elif self.ball.velocity_y <= -0.015:
			discrete_velocity_y = -1
		else:
			discrete_velocity_y = 0

		discrete_paddle = min(11, int(math.floor(12 * self.paddle.y/(1 - self.paddle.height))))

		discrete_ball_x =  min(11, int(math.floor(12 * self.ball.x)))
		discrete_ball_y =  min(11, int(math.floor(12 * self.ball.y)))

		return (discrete_ball_x, discrete_ball_y, discrete_velocity_x, discrete_velocity_y, discrete_paddle)

	def end_game(self):
		if len(self.scores) == 1000:
			self.scores = self.scores[1:]
		self.scores.append(self.score)
		self.score = 0
		self.game_number += 1
		self.is_active = False

		if self.game_number%1000 == 0:
			average = float(sum(self.scores))/1000.0
			print(self.game_number, average)
			self.training_stats.append((self.game_number, average))

		if self.game_number == 20000:
			self.finished_training = True

	def end_test_game(self):
		self.test_stats.append((self.game_number, self.score))
		self.game_number += 1
		self.score = 0
		self.is_active = False

		if self.game_number == 200:
			self.finished_testing = True

	def check_terminal_state(self, mode):
		if self.ball.x > self.paddle.x:
			if self.ball.y > self.paddle.y and self.ball.y < self.paddle.y + self.paddle.height:
				self.ball.hit_paddle()
				self.score += 1
				return True
			else:
				if mode == 'test':
					self.end_test_game()
					return False
				else:
					self.end_game()
					return False
		else:
			return False

	def update_q(self):
		hit_paddle = self.check_terminal_state('train')
		discrete_state = self.discretize_state()

		if self.is_active == False:
			self.reward = -1.0
			if self.previous_state is not None:
				self.agent.learn(self.previous_state, self.previous_action, self.reward, discrete_state)
			self.previous_state = None
			self.ball = Ball()
			self.paddle = Paddle()
			self.is_active = True
			return

		if hit_paddle is True:
			self.reward = 1.0

		if self.previous_state != None:
			self.agent.learn(self.previous_state, self.previous_action, self.reward, discrete_state)

		new_state = self.discretize_state()
		new_action = self.agent.choose_action(new_state)

		self.previous_state = new_state
		self.previous_action = new_action
		self.paddle.update(new_action)
		self.ball.update()
		self.reward = 0.0

	def update_sarsa(self):
		hit_paddle = self.check_terminal_state('train')
		discrete_state = self.discretize_state()
		action = self.sarsa_agent.choose_action(discrete_state)

		if self.is_active == False:
			self.reward = -1.0
			if self.previous_state is not None:
				self.sarsa_agent.learn(self.previous_state, self.previous_action, self.reward, discrete_state, action)
			self.previous_state = None
			self.ball = Ball()
			self.paddle = Paddle()
			self.is_active = True
			return

		if hit_paddle is True:
			self.reward = 1.0

		if self.previous_state != None:
			self.sarsa_agent.learn(self.previous_state, self.previous_action, self.reward, discrete_state, action)

		new_state = self.discretize_state()
		new_action = self.sarsa_agent.choose_action(new_state)

		self.previous_state = new_state
		self.previous_action = new_action
		self.paddle.update(new_action)
		self.ball.update()
		self.reward = 0.0

	def update_test_q(self):
		hit_paddle = self.check_terminal_state('test')
		discrete_state = self.discretize_state()

		if self.is_active == False:
			self.ball = Ball()
			self.paddle = Paddle()
			self.is_active = True
			return

		new_state = self.discretize_state()
		new_action = self.agent.choose_action(new_state)

		self.paddle.update(new_action)
		self.ball.update()

	def update_test_sarsa(self):
		hit_paddle = self.check_terminal_state('test')
		discrete_state = self.discretize_state()

		if self.is_active == False:
			self.ball = Ball()
			self.paddle = Paddle()
			self.is_active = True
			return

		new_state = self.discretize_state()
		new_action = self.sarsa_agent.choose_action(new_state)

		self.paddle.update(new_action)
		self.ball.update()

	def init_nagent(self, W, B, normalize):
		self.nagent = nnet_agent.NAgent(W, B, normalize)

	def update_test_nagent(self):
		hit_paddle = self.check_terminal_state('test')
	
		if self.is_active == False:
			self.ball = Ball()
			self.paddle = Paddle()
			self.is_active = True
			return

		new_state = (self.ball.x, self.ball.y, self.ball.velocity_x, self.ball.velocity_y, self.paddle.y)
		new_action = self.nagent.choose_action(new_state)
		# print(new_action)

		self.paddle.update(new_action)
		self.ball.update()
		self.state = (self.ball.x, self.ball.y, self.ball.velocity_x, self.ball.velocity_y, self.paddle.y)
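
The QLearning/SARSA agents driving this Game class take three positional constructor arguments (here SARSA(10, 0.7, 0.05)) and expose choose_action(state) and learn(...). A minimal SARSA-style agent compatible with update_sarsa above could look like the sketch below; the meaning of the constructor arguments (learning-rate constant, discount, exploration rate), the paddle action encoding, and the count-based learning rate are all assumptions.

import random
from collections import defaultdict


class SARSAPongAgent:
    """Hedged sketch of a SARSA agent compatible with Game.update_sarsa above."""

    ACTIONS = (-1, 0, 1)  # paddle: up, stay, down (assumed encoding)

    def __init__(self, lr_constant, gamma, epsilon):
        self.lr_constant = lr_constant
        self.gamma = gamma
        self.epsilon = epsilon
        self.q = defaultdict(float)     # (state, action) -> Q-value
        self.visits = defaultdict(int)  # (state, action) -> visit count

    def choose_action(self, state):
        # Epsilon-greedy choice over the discretized state
        if random.random() < self.epsilon:
            return random.choice(self.ACTIONS)
        return max(self.ACTIONS, key=lambda a: self.q[(state, a)])

    def learn(self, state, action, reward, next_state, next_action):
        # Decaying learning rate alpha = C / (C + N(s, a)) is an assumption
        self.visits[(state, action)] += 1
        alpha = self.lr_constant / (self.lr_constant + self.visits[(state, action)])
        target = reward + self.gamma * self.q[(next_state, next_action)]
        self.q[(state, action)] += alpha * (target - self.q[(state, action)])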
Example #18
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 17 15:21:28 2018

@author: Shiratori
"""

import graphics as gfx
import pong_model as pm
from sarsa import SARSA

if __name__ == '__main__':
    # Set up environment
    environment = pm.PongModel(0.5, 0.5, 0.03, 0.01, 0.4, paddleX=1)
    window = gfx.GFX(wall_x=0, player_x=1)
    window.fps = 9e16
    
    # Set up model
    model = SARSA(environment, window, alpha=0.5, gamma=0.99, explore=-1, threshold=-1, 
                  log=True, log_file='test_sarsa_log.txt', mode='test')
    
    # Run the SARSA loop (the model is configured in 'test' mode above)
    model.train()
Example #19
def On_Pol(cnstnts=(None, None)):
    if "1" in multi.current_process().name:
        print(('%s began working' % multi.current_process().name))
    agent = SARSA(env_in=EZ21(), n_in=80000, cnst_par=cnstnts)
    return agent.iter_opt(), agent.Q
Example #20
def train(env, config, output=True):
    """
    Train and evaluate SARSA on given environment with provided hyperparameters

    :param env (gym.Env): environment to execute evaluation on
    :param config (Dict[str, float]): configuration dictionary containing hyperparameters
    :param output (bool): flag if mean evaluation results should be printed
    :return (float, List[float], List[float], List[float], Dict[(Obs, Act), float]):
        total reward over all episodes, list of means and standard deviations of evaluation
        rewards, list of epsilons at evaluation time, final Q-table
    """
    agent = SARSA(
        num_acts=env.action_space.n,
        gamma=config["gamma"],
        epsilon=config["epsilon"],
        alpha=config["alpha"],
    )

    step_counter = 0
    # upper bound on total environment steps, used for hyperparameter scheduling
    max_steps = config["total_eps"] * config["max_episode_steps"]

    total_reward = 0
    evaluation_reward_means = []
    evaluation_reward_stds = []
    evaluation_epsilons = []

    for eps_num in range(config["total_eps"]):
        obs = env.reset()
        episodic_reward = 0
        steps = 0
        done = False

        # take first action
        act = agent.act(obs)

        while not done and steps < config["max_episode_steps"]:
            steps += 1
            n_obs, reward, done, info = env.step(act)
            step_counter += 1
            episodic_reward += reward

            agent.schedule_hyperparameters(step_counter, max_steps)
            n_act = agent.act(n_obs)
            agent.learn(obs, act, reward, n_obs, n_act, done)

            obs = n_obs
            act = n_act

        total_reward += episodic_reward

        if eps_num > 0 and eps_num % config["eval_freq"] == 0:
            mean_reward, std_reward = evaluate(env,
                                               config,
                                               agent.q_table,
                                               eps_num,
                                               render=RENDER,
                                               output=output)
            evaluation_reward_means.append(mean_reward)
            evaluation_reward_stds.append(std_reward)
            evaluation_epsilons.append(agent.epsilon)

    return total_reward, evaluation_reward_means, evaluation_reward_stds, evaluation_epsilons, agent.q_table
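
schedule_hyperparameters is invoked once per environment step in both train functions but is not shown on this page. One plausible implementation, purely an assumption here, is a linear epsilon decay over an initial fraction of the training budget:

def schedule_hyperparameters(self, timestep, max_timestep):
    """Hedged sketch: decay epsilon from 1.0 towards 0.05 over the first 20% of steps."""
    decay_fraction = 0.20  # assumed fraction of training used for exploration decay
    progress = min(1.0, timestep / (decay_fraction * max_timestep))
    self.epsilon = 1.0 - 0.95 * progress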
Example #21
from Qlearning import Qlearning
from sarsa import SARSA

if __name__ == "__main__":
    '''method = 0
    if method==0:
        model = np.load('qlearning_1_1.npz')
        env = CatAndMouseEnv(mode_obstacle=model['mode_obstacle'],mode_mouse=model['mode_mouse'],map=model['map'],mouse=model['mouse'])
        q = Qlearning(mode_obstacle=model['mode_obstacle'],mode_mouse=model['mode_mouse'],map=model['map'],Q=model['Q'],mouse=model['mouse'])
    else:
        model = np.load('sarsa_1_1.npz')
        env = CatAndMouseEnv(mode_obstacle=model['mode_obstacle'],mode_mouse=model['mode_mouse'],map=model['map'],mouse=model['mouse'])
        q = SARSA(mode_obstacle=model['mode_obstacle'],mode_mouse=model['mode_mouse'],map=model['map'],Q=model['Q'],mouse=model['mouse'])
    q.visualization()
    print(model['map'])
    print(model['Q'])
    print(model['mode_obstacle'])
    print(model['mode_mouse'])'''

    method = 0
    if method == 0:
        model = np.load('qlearning_1_1.npz')
        q = Qlearning(8, 8, mode_obstacle=1, mode_mouse=1, Q=model['Q'])
    else:
        model = np.load('sarsa_1_1.npz')
        q = SARSA(8, 8, mode_obstacle=1, mode_mouse=1, Q=model['Q'])
    q.visualization()
    q.test()

    print("env closed")
def train(env, config, output=True):
    """
    Train and evaluate SARSA on given environment with provided hyperparameters

    :param env (gym.Env): environment to execute evaluation on
    :param config (Dict[str, float]): configuration dictionary containing hyperparameters
    :param output (bool): flag whether mean evaluation performance should be printed
    :return (List[float], List[float], List[float], Dict[(Obs, Act), float]):
        list of means and standard deviations of evaluation returns, list of epsilons, final Q-table
    """
    agent = SARSA(
        num_acts=env.action_space.n,
        gamma=config["gamma"],
        epsilon=config["epsilon"],
        alpha=config["alpha"],
    )

    step_counter = 0
    # upper bound on total environment steps, used for hyperparameter scheduling
    max_steps = config["total_eps"] * config["max_episode_steps"]

    evaluation_return_means = []
    evaluation_return_stds = []
    evaluation_epsilons = []

    for eps_num in range(config["total_eps"]):
        obs = env.reset()
        episodic_return = 0
        steps = 0
        done = False

        # take first action
        act = agent.act(obs)

        while not done and steps < config["max_episode_steps"]:
            steps += 1
            n_obs, reward, done, info = env.step(act)
            step_counter += 1
            episodic_return += reward

            agent.schedule_hyperparameters(step_counter, max_steps)
            n_act = agent.act(n_obs)
            agent.learn(obs, act, reward, n_obs, n_act, done)

            obs = n_obs
            act = n_act

        if eps_num % config["eval_freq"] == 0:
            mean_return, std_return, negative_returns = evaluate(
                env,
                config,
                agent.q_table,
                render=RENDER,
            )
            if output:
                print(
                    f"EVALUATION: EP {eps_num} - MEAN RETURN {mean_return} +/- {std_return} ({negative_returns}/{config['eval_episodes']} failed episodes)"
                )
            evaluation_return_means.append(mean_return)
            evaluation_return_stds.append(std_return)
            evaluation_epsilons.append(agent.epsilon)

    return evaluation_return_means, evaluation_return_stds, evaluation_epsilons, agent.q_table
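
For completeness, a hedged usage sketch for the second train/evaluate pair above. The config keys are exactly the ones those functions read; the environment name, the RENDER flag, and the concrete hyperparameter values are assumptions (the snippets use the classic gym API where env.step returns a 4-tuple).

import gym

RENDER = False  # module-level flag referenced inside evaluate() (assumed value)

CONFIG = {
    "total_eps": 10000,
    "eval_freq": 1000,
    "eval_episodes": 100,
    "max_episode_steps": 100,
    "alpha": 0.1,
    "gamma": 0.99,
    "epsilon": 1.0,
}

if __name__ == "__main__":
    env = gym.make("FrozenLake-v0")  # any discrete-observation env; name is an assumption
    means, stds, epsilons, q_table = train(env, CONFIG)
    print(f"Final evaluation mean return: {means[-1] if means else float('nan')}")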