Example #3
def main(gridworld_size, supervision, specific_holes):
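    # Note: replay_memory, replay_memory_size, get_input and save_replay_memory
    # are helpers/globals defined elsewhere in this script (not shown in this
    # listing); main() only fills the buffer with transitions and saves it to disk.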


    change = bool(specific_holes)
    dynamic_holes = True
    # holes = [[5,1],[1,2],[3,3],[2,4],[5,5]]
    use_holes = False
    # supervision = False
    Mark = False

    # Get the environment and extract the number of actions.
    env = GridworldGym(dynamic_holes=dynamic_holes,constant_change=change, gridworld_size=gridworld_size, specific_holes=specific_holes, self_play=False)
    if not change:
        specific_holes = env.hole_pos

    done = False


    s = env.reset()

    counter = 0

    old_states = [0,0,s]
    while counter < replay_memory_size:
        if done:
            s = env.reset()

        if Mark:
            a = get_input()
        else:
            a = env.optimal_choice()


        s_next, r, done, _ = env.step(a)

        old_states = [old_states[1], old_states[2], s_next]
        # end the episode if the agent is back in the state it was in two steps ago
        if np.array_equal(old_states[0], old_states[2]):
            done = True
        if supervision:
            opt_a = env.optimal_choice()

        if not supervision:
            replay_memory.push((s, a, r, s_next, done))
        else:
            replay_memory.push((s, a, opt_a, r, s_next, done))
        if Mark:
            time.sleep(0.3)
        counter +=1

    if not change:
        save_replay_memory(replay_memory, gridworld_size, supervision, hole_pos=specific_holes, Mark=Mark)
    else:
        save_replay_memory(replay_memory, gridworld_size, supervision, Mark=Mark)
Example #4
import pygame
import random
import pickle
import numpy as np
import tensorflow as tf

from multiprocessing import Process

from datetime import datetime

from GridworldGym import GridworldGym

env = GridworldGym()


class EmphaticQLearner():
    def __init__(self, load_q_values=True):
        self.Q_values = {}
        self.load_pickles(load_q_values)
        self.epsilon = 0.99
        self.learning_rate = 0.05
        self.future_discount = 0.99
        self.selfishness = 0.5
        self.writer = tf.summary.FileWriter(
            f'logs/LRLearning3.0/{str(datetime.now())}')
        self.step = 0
        self.log_q_values = [[]]

        for i in range(10000000):
            self.step += 1
            if self.epsilon > 0.1:
                self.epsilon = self.epsilon * 0.9999

            self.Q_learning()
class trainer_Q_network(object):
    def __init__(self, network=NETWORK, num_episodes=NUM_EPISODES,
                 memory_size=MEMORY_SIZE, seed=SEED, discount_factor=DISCOUNT_FACTOR,
                 headless=True, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE,
                 dynamic_holes=DYNAMIC_HOLES, dynamic_start=DYNAMIC_START, load_episode=False, save_every=SAVE_EVERY,
                 plot_every=PLOT_EVERY, plotting=False, embedding=EMBEDDING, gridworld_size=7, change=False,
                 supervision=False, google=False, print_every=PRINT_EVERY, name=None, load_memory=False, Mark=False):

        self.num_episodes = num_episodes
        self.memory = ReplayMemory(memory_size)
        self.memory_size = memory_size
        self.discount_factor = discount_factor
        self.Mark = Mark

        self.headless = headless
        self.dynamic_holes = dynamic_holes
        self.dynamic_start = dynamic_start
        self.net_name = network
        self.memory_folder = 'memories'

        # `network` is a class here, so compare its own __name__ (not __class__.__name__)
        if embedding and network.__name__ != 'SimpleCNN':
            self.env = GridworldGym(headless=headless, dynamic_holes=dynamic_holes, dynamic_start=dynamic_start, embedding=embedding, constant_change=change,
                                    gridworld_size=gridworld_size)
            self.network = network(embedding=embedding, gw_size=gridworld_size).to(device)
        else:
            self.env = GridworldGym(headless=headless, dynamic_holes=dynamic_holes, dynamic_start=dynamic_start, constant_change=change, gridworld_size=gridworld_size, self_play=False)
            self.network = network(gw_size=gridworld_size).to(device)
        # self.initialize(seed)
        self.batch_size = batch_size
        self.save_every = save_every
        self.plot_every = plot_every
        self.print_every = print_every
        self.embedding = embedding
        self.gridworld_size = gridworld_size
        self.supervision = supervision
        self.optimizer = optim.Adam(self.network.parameters(), learning_rate, weight_decay=0.01)
        self.episode_number = 0
        self.num_deaths = 0
        self.episode_durations = []
        self.number_of_deaths = []
        self.google = google
        self.load_memory = load_memory
        self.rewards = []
        self.steps = 0
        self.constant_change = change
        self.supervision_episodes = 1000
        self.loss = 0
        self.mark = Mark
        self.experiment_name = 'checkpoint_{}_DH={}_DS={}_em={}_final2_sup={}_load_mem={}_size={}_i={}'.format(self.network.__class__.__name__, change, dynamic_start, self.embedding, supervision, load_memory, self.gridworld_size, name)
        self.exp_folder = 'checkpoints'
        self.fig_folder = 'figures'
        self.smooth_factor = 100
        if load_memory:
            self.load_replay_memory()

        if self.embedding and not check_conv_net(self.network):
            raise ValueError('We cannot combine embeddings with a Feed-Forward network!')

        if load_episode:
            last_checkpoint, episode = find_last_checkpoint(self.experiment_name, self.exp_folder)
            if last_checkpoint:
                print('Continue from episode {}'.format(self.episode_number))
                self.episode_number = episode
                self.load_model(last_checkpoint)
        # if not plotting:
        #     self.run_episodes()
        # else:
        #     self.plot_results()
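
    # ReplayMemory is defined elsewhere in the project (it exposes push, sample,
    # __len__ and a .memory list). A minimal sketch of such a buffer, purely as
    # an assumption about its interface, could look like:
    #
    #   class ReplayMemory:
    #       def __init__(self, capacity):
    #           self.capacity = capacity
    #           self.memory = []
    #       def push(self, transition):
    #           self.memory.append(transition)
    #           if len(self.memory) > self.capacity:
    #               del self.memory[0]
    #       def sample(self, batch_size):
    #           return random.sample(self.memory, batch_size)
    #       def __len__(self):
    #           return len(self.memory)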

    def load_replay_memory(self):
        files = []
        for file in os.listdir('./{}'.format(self.memory_folder)):
            if file.endswith('.pt') and file.startswith('GW'):
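                # the filename is assumed to encode memory length, gridworld size,
                # supervision flag, Mark flag and (optionally) hole positions,
                # e.g. roughly 'GW_<len>_gw=<size>_sup=<bool>_Mark=<bool>[_hp=<holes>].pt';
                # the exact field prefixes are inferred from the slicing below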
                splitted = file.split('_')
                gw = int(splitted[2][3:])
                sup = strtobool(splitted[3][4:])
                mem_len = int(splitted[1])


                if 'hp' in file:
                    mark = strtobool(splitted[4][5:])
                    hole_positions = ast.literal_eval(splitted[5][3:-3])
                else:
                    hole_positions = False
                    mark = strtobool(splitted[4][5:-3])
                if gw == self.gridworld_size and sup == self.supervision and mark == self.Mark:
                    if (not self.constant_change and hole_positions) or (self.constant_change and not hole_positions):
                        files.append([file, hole_positions, mem_len])

        if len(files) == 0:
            raise ValueError('No such memory available!')
        outcome = max(files, key=lambda x: x[2])
        self.hole_positions = outcome[1]

        if not self.constant_change:
            if self.embedding and self.network.__class__.__name__ != 'SimpleCNN':
                self.env = GridworldGym(headless=self.headless, dynamic_holes=self.dynamic_holes, dynamic_start=self.dynamic_start, embedding=self.embedding, specific_holes=outcome[1], constant_change=self.constant_change,
                                        gridworld_size=self.gridworld_size)
                self.network = self.net_name(embedding=self.embedding, gw_size=self.gridworld_size).to(device)
            else:
                self.env = GridworldGym(headless=self.headless, dynamic_holes=self.dynamic_holes, dynamic_start=self.dynamic_start,constant_change=self.constant_change, specific_holes=outcome[1], gridworld_size=self.gridworld_size)
                self.network = self.net_name(gw_size=self.gridworld_size).to(device)


        with open('{}/{}'.format(self.memory_folder, outcome[0]), 'rb') as pf:
            memory = pickle.load(pf)

        # repeat the stored transitions if the loaded memory is smaller than requested
        if len(memory) < self.memory_size:
            factor = int(self.memory_size / outcome[2])
            rm = factor * memory.memory
        else:
            rm = memory.memory

        for mem in rm:
            self.memory.push(mem)
        return

    def initialize(self, seed):
        random.seed(seed)
        torch.manual_seed(seed)
        # self.env.seed(seed)


    def train(self, epsilon):
        # DO NOT MODIFY THIS FUNCTION

        # don't learn without some decent experience
        if len(self.memory) < self.batch_size:
            return None

        # random transition batch is taken from experience replay memory
        transitions = self.memory.sample(self.batch_size)

        # transition is a list of 4-tuples, instead we want 4 vectors (as torch.Tensor's)
        if self.supervision:
            state, action, opt_action, reward, next_state, done = zip(*transitions)
            opt_action = torch.tensor(opt_action, dtype=torch.int64).to(device)
        else:
            state, action, reward, next_state, done = zip(*transitions)

            # opt_action = torch.tensor(a, dtype=torch.int64).to(device)
        # convert to PyTorch and define types
        state = torch.tensor(state, dtype=torch.float).to(device)

        action = torch.tensor(action, dtype=torch.int64).to(device)

        next_state = torch.tensor(next_state, dtype=torch.float).to(device)
        reward = torch.tensor(reward, dtype=torch.float).to(device)
        done = torch.tensor(done, dtype=torch.uint8).to(device)  # Boolean

        # compute the q value
        q_val = compute_q_val(self.network, state, action)

        with torch.no_grad():  # Don't compute gradient info for the target (semi-gradient)
            target = compute_target(self.network, reward, next_state, done, self.discount_factor)

        # loss is measured from error between current and newly expected Q values

        if self.supervision and self.episode_number < self.supervision_episodes:
            action_probs = self.network(state)

            # large-margin supervised term: a margin of 1 for every non-expert action
            margin = torch.ones_like(action_probs)
            batch_idx = torch.arange(len(action_probs), device=action_probs.device)
            margin[batch_idx, opt_action] = 0

            loss = F.smooth_l1_loss(q_val, target)
            #
            # one_hot = F.one_hot(opt_action).type(torch.FloatTensor)
            # super_loss = F.cross_entropy(action_probs, opt_action)
            Q = action_probs + margin
            action_e = action_probs[batch_idx, opt_action]
            # per-sample max over actions, summed over the batch
            super_loss = torch.sum(torch.max(Q, dim=1)[0] - action_e)

            loss += super_loss
            # torch.mean(loss)
        else:
            loss = F.smooth_l1_loss(q_val, target)

        # backpropagation of loss to Neural Network (PyTorch magic)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()
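
    # compute_q_val, compute_target, get_epsilon and select_action are helpers
    # defined elsewhere in the project. For a standard DQN update they are
    # presumably along these lines (hypothetical sketch, not the author's code):
    #
    #   def compute_q_val(model, state, action):
    #       # Q(s, a) for the actions that were actually taken
    #       return model(state).gather(1, action.unsqueeze(1)).squeeze(1)
    #
    #   def compute_target(model, reward, next_state, done, discount_factor):
    #       # r + gamma * max_a' Q(s', a'), with bootstrapping cut off at terminal states
    #       next_q = model(next_state).max(dim=1)[0]
    #       return reward + discount_factor * next_q * (1 - done.float())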


    def run_episodes(self):
        # Count the steps (do not reset at episode start, to compute epsilon)
        episode_duration = 0

        while self.episode_number < self.num_episodes:
            if self.episode_number % self.print_every == 0:
                print('Currently working on episode {}'.format(self.episode_number))
            # print('Currently working on episode {}'.format(self.episode_number))
            done = False
            episode_duration = 0
            self.episode_number += 1
            s = self.env.reset()
            rew = 0

            while not done:

                epsilon = get_epsilon(self.steps)
                episode_duration += 1
                self.steps += 1
                a = select_action(self.network, s, epsilon)
                if self.supervision:
                    opt_a = self.env.optimal_choice()
                s_next, r, done, _ = self.env.step(a)
                rew += r

                # Push a transition
                if not self.supervision:
                    self.memory.push((s, a, r, s_next, done))
                else:
                    self.memory.push((s, a, opt_a, r, s_next, done))
                s = s_next
                self.loss = self.train(epsilon)

            self.episode_durations.append(episode_duration)
            # print('The episode lasted {}'.format(episode_duration))
            # print('Currently: Epsilon is {} after {} Episodes'.format(epsilon, self.episode_number))
            # print('The outcome is {}'.format(r))
            self.rewards.append(r)
            if r == -1:
                self.num_deaths += 1
            self.number_of_deaths.append(self.num_deaths)

            if self.episode_number % self.save_every == 0:
                self.save_model()

            if self.episode_number % self.plot_every == 0:
                if not self.google:
                    self.plot_results()
                else:
                    self.save_results()
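
    # get_epsilon and select_action are defined elsewhere in the project. A
    # hypothetical epsilon-greedy sketch consistent with how they are called here:
    #
    #   def get_epsilon(it, start=1.0, end=0.05, anneal_steps=10000):
    #       # linearly anneal epsilon with the global step count
    #       return max(end, start - (start - end) * it / anneal_steps)
    #
    #   def select_action(model, state, epsilon):
    #       with torch.no_grad():
    #           q = model(torch.tensor([state], dtype=torch.float).to(device))
    #       if random.random() < epsilon:
    #           return random.randrange(q.shape[1])
    #       return int(q.argmax(dim=1).item())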

    def save_model(self):

        file_path = '{}/{}_ep={}.pt'.format(self.exp_folder, self.experiment_name, self.episode_number)
        torch.save({
            'epoch': self.episode_number,
            'model_state_dict': self.network.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': self.loss,
            'rewards': self.rewards,
            'episode_durations': self.episode_durations,
            'number_deaths': self.num_deaths,
            'deaths': self.number_of_deaths
        }, file_path)

    def load_model(self, file_path=False):
        if not file_path:
            file_path = '{}/{}_ep={}.pt'.format(self.exp_folder, self.experiment_name, self.episode_number)
        checkpoint = torch.load(file_path)

        self.network.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.loss = checkpoint['loss']
        self.num_deaths = checkpoint['number_deaths']
        self.episode_durations = checkpoint['episode_durations']
        self.steps = sum(self.episode_durations)
        self.rewards = checkpoint['rewards']
        self.number_of_deaths = checkpoint['deaths']

    def plot_results(self):
        fig = plt.figure()
        plt.plot(smooth(self.episode_durations, self.smooth_factor))
        plt.title('Episode durations per episode')
        fig.savefig('{}/{}_ep_dur.png'.format(self.fig_folder, self.experiment_name))

        fig = plt.figure()
        plt.plot(smooth(self.rewards, self.smooth_factor))
        plt.title('Reward per episode')
        fig.savefig('{}/{}_rewards.png'.format(self.fig_folder, self.experiment_name))

        fig = plt.figure()
        plt.plot(self.number_of_deaths)
        plt.title('Number of deaths over time')
        fig.savefig('{}/{}_num_deaths.png'.format(self.fig_folder, self.experiment_name))

        print('The current number of deaths is {} after {} episodes'.format(self.num_deaths, self.episode_number))
        plt.close()

    def save_results(self):
        fn = '{}/{}_data.pt'.format(self.fig_folder, self.experiment_name)

        data = {'Episode_durations': self.episode_durations, 'Rewards': self.rewards, 'Number of Deaths': self.number_of_deaths}

        with open(fn, "wb") as pf:
            pickle.dump(data, pf)
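
# smooth() used in plot_results() is a small helper from elsewhere in the
# project, presumably a moving average over the last `n` values, e.g. (hypothetical):
#
#   def smooth(x, n):
#       if len(x) < n:
#           return x
#       return np.convolve(x, np.ones(n) / n, mode='valid')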
class EmphaticQLearner():

    def __init__(self, load_q_values=True, gridworld_size=7):
        self.Q_values = {}
        self.load_pickles(load_q_values)
        self.epsilon = 0.99
        self.learning_rate = 0.05
        self.gridworld_size = gridworld_size
        self.future_discount = 0.99
        self.selfishness = 0.5
        self.writer = tf.summary.FileWriter(f'logs/Q_Tab_Grid/{str(datetime.now())}')
        self.step = 0
        self.env = GridworldGym(headless=True, dynamic_holes=True, dynamic_start=False, constant_change=True, gridworld_size=gridworld_size)
        self.episodes = 0
        self.episode_steps = 0
        self.episode_durations = []
        self.log_q_values=[[]]
        # self.rewards = []
        self.total_death = 0
        self.total_succeed = 0


    def train(self):
        while self.episodes < 10000:
            self.step += 1
            if self.epsilon > 0.1:
                self.epsilon = self.epsilon * 0.9999

            self.Q_learning()


    def save_rewards(self):

        fn = 'dynamic_pickles/{}_{}_{}.pt'.format(self.gridworld_size, 'Table',
                                                 datetime.now().timestamp())

        with open(fn, "wb") as pf:
            pickle.dump((self.rewards, self.gridworld_size, 'Table', self.episode_durations), pf)
        print('Saved an experiment for the Q_table with size {}.'.format(self.gridworld_size))


    def log_scalar(self, tag, value, global_step):
        summary = tf.Summary()
        summary.value.add(tag=tag, simple_value=value)
        self.writer.add_summary(summary, global_step=global_step)
        self.writer.flush()

    def log_histogram(self, tag, values, global_step, bins):
        counts, bin_edges = np.histogram(values, bins=bins)

        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(np.prod(values.shape))
        hist.sum = float(np.sum(values))
        hist.sum_squares = float(np.sum(values**2))

        bin_edges = bin_edges[1:]

        for edge in bin_edges:
            hist.bucket_limit.append(edge)
        for c in counts:
            hist.bucket.append(c)

        summary = tf.Summary()
        summary.value.add(tag=tag, histo=hist)
        self.writer.add_summary(summary, global_step=global_step)
        self.writer.flush()

    def load_pickles(self, load_q_values):
        if not load_q_values:
            with open('Q_values.pickle', 'wb') as handle:
                pickle.dump(self.Q_values, handle, protocol=pickle.HIGHEST_PROTOCOL)
                self.rewards = [0]
        else:
            with open('Q_values.pickle', 'rb') as handle:
                self.Q_values = pickle.load(handle)
            with open('coins_collected.pickle', 'rb') as handle:
                self.rewards = pickle.load(handle)
            with open('goombas_killed.pickle', 'rb') as handle:
                self.enemies_killed = pickle.load(handle)

    def do_game_step(self, move):

        next_state, reward, done, info = self.env.step(move)

        if done:
            self.env.reset()
            self.rewards[-1] += reward
            self.log_scalar('reward', self.rewards[-1], self.step)
            self.log_scalar('epsilon', self.epsilon, self.step)
            self.log_scalar('mean_q', np.mean(self.log_q_values[-1]), self.step)
            # self.log_histogram('q_values', np.array(self.log_q_values[-1]), self.step, 20)
            self.rewards.append(0)
            self.log_q_values.append([])

        return next_state, reward, done, info

    def level_to_key(self, obs):
        obs1 = tuple(map(tuple, obs))
        return obs1


    def get_best_action(self, state):
        max_Q = -np.inf
        best_action = None
        state_x, state_y = state[0], state[1]
        if state not in self.Q_values:
            self.Q_values[state] = {}
        for action in range(4):
            if action not in self.Q_values[state]:
                self.Q_values[state][action] = 1
            if self.Q_values[state][action] >= max_Q:
                max_Q = self.Q_values[state][action]
                best_action = action
        return best_action, max_Q

    def Q_learning(self):
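        # Tabular Q-learning with epsilon-greedy exploration over the four grid actions:
        #   Q(s, a) <- (1 - lr) * Q(s, a) + lr * (r + gamma * max_a' Q(s', a'))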

        state = tuple(self.env.agent_position)
        best_action, max_Q = self.get_best_action(state)

        if np.random.random() > self.epsilon:
            action = best_action
        else:
            action = random.choice(range(4))
        next_state, reward, done, info = self.do_game_step(action)

        next_state = tuple(self.env.agent_position)

        _, new_Q = self.get_best_action(next_state)
        value =  (reward + self.future_discount * new_Q)
        self.Q_values[state][action] = (1-self.learning_rate)*self.Q_values[state][action] + self.learning_rate*value
        self.episode_steps += 1
        # self.rewards[-1] += reward
        if 'death' in info:
            self.total_death += info['death']
        if 'succeed' in info:
            self.total_succeed += info['succeed']
        self.log_q_values[-1].append(max_Q)
        if done:
            self.episodes += 1

            self.episode_durations.append(self.episode_steps)
            self.episode_steps = 0
            if len(self.rewards) > 500:
                average_last = np.mean(self.rewards[-500:])
            else:
                average_last = np.mean(self.rewards)
Example #8
import time
import numpy as np

from keras.models import Sequential
from keras.layers import Dense, Activation, Dropout, Flatten, Convolution2D, Permute
from keras.optimizers import Adam
from keras.callbacks import TensorBoard, ModelCheckpoint
import keras.backend as K
from rl.agents.dqn import DQNAgent
from rl.policy import BoltzmannQPolicy, LinearAnnealedPolicy, EpsGreedyQPolicy
from rl.memory import SequentialMemory
from rl.core import Processor
from MarioGym import MarioGym
from GridworldGym import GridworldGym



# Get the environment and extract the number of actions.
env = GridworldGym()
np.random.seed(123)
env.seed(123)
nb_actions = env.action_space.n
tb_log_dir = 'logs/tmp/{}'.format(time.time())
tb_callback = TensorBoard(log_dir=tb_log_dir, batch_size=32, write_grads=True, write_images=True)
cp = ModelCheckpoint('logs/cp/checkpoint-{episode_reward:.2f}-{epoch:02d}-.h5f', monitor='episode_reward', verbose=0, save_best_only=False, save_weights_only=True, mode='max', period=500)
INPUT_SHAPE = (7, 7)

# HYPERPARAMETERS
TRAINING_STEPS = 5000000
WINDOW_LENGTH = 4
REPLAY_MEMORY = 500000
MAX_EPSILON = 0.5
MIN_EPSILON = 0.0
EPSILON_DECAY_PERIOD = 0.75
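
# The listing ends here; the model and agent construction are not shown. A hedged
# sketch of how these pieces are typically wired together with keras-rl
# (the names and layer sizes below are assumptions, not the author's code):
#
#   model = Sequential([Flatten(input_shape=(WINDOW_LENGTH,) + INPUT_SHAPE),
#                       Dense(64, activation='relu'),
#                       Dense(nb_actions, activation='linear')])
#   memory = SequentialMemory(limit=REPLAY_MEMORY, window_length=WINDOW_LENGTH)
#   policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps',
#                                 value_max=MAX_EPSILON, value_min=MIN_EPSILON,
#                                 value_test=0.0,
#                                 nb_steps=int(TRAINING_STEPS * EPSILON_DECAY_PERIOD))
#   dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, policy=policy)
#   dqn.compile(Adam(lr=1e-4), metrics=['mae'])
#   dqn.fit(env, nb_steps=TRAINING_STEPS, callbacks=[tb_callback, cp])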