Example 1
def run_model(game_count=1):
    """
    run model for game_count games
    """

    # Make environment
    env = WhaleEnv(
        config={
            'active_player': 0,
            'seed': datetime.utcnow().microsecond,
            'env_num': 1,
            'num_players': 5
        })
    # Set up agents
    action_num = 3
    agent = DqnAgent(action_num=action_num, player_num=5)
    agent_0 = RandomAgent(action_num=action_num)
    agent_1 = RandomAgent(action_num=action_num)
    agent_2 = RandomAgent(action_num=action_num)
    agent_3 = RandomAgent(action_num=action_num)
    agents = [agent, agent_0, agent_1, agent_2, agent_3]
    env.set_agents(agents)
    agent.load_pretrained()
    for game in range(game_count):

        # Generate data from the environment
        trajectories = env.run(is_training=False)

        # Print out the trajectories
        print('\nEpisode {}'.format(game))
        for i, trajectory in enumerate(trajectories):
            print('\tPlayer {}'.format(i))
            print(trajectory[-1])
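
A minimal, hypothetical entry point for this example (the game count is an arbitrary choice, not taken from the original source):

if __name__ == '__main__':
    # play a handful of evaluation games with the pre-trained agent
    run_model(game_count=5)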
Example 2
    def __init__(self, name, type_):
        self.score = 0
        self.name = name
        self.hand = None
        self.mask = None
        self.type = type_

        if self.type == PlayerType.Computer:
            self.agent = DqnAgent(train=False)
Example 3
def train_model(max_episodes=1000):
    """
    Trains a DQN agent to play the Whale game by trial and error

    :return: None
    """

    # buffer = ReplayBuffer()
    # Make environment
    env = WhaleEnv(
        config={
            'allow_step_back': False,
            'allow_raw_data': False,
            'single_agent_mode': False,
            'active_player': 0,
            'record_action': False,
            'seed': 0,
            'env_num': 1,
            'num_players': 5
        })
    # Set a global seed using time
    set_global_seed(datetime.utcnow().microsecond)
    # Set up agents
    action_num = 3
    agent = DqnAgent(dim=1, action_num=action_num)
    agent_0 = RandomAgent(action_num=action_num)
    agent_1 = RandomAgent(action_num=action_num)
    agent_2 = RandomAgent(action_num=action_num)
    agent_3 = RandomAgent(action_num=action_num)
    agents = [agent, agent_0, agent_1, agent_2, agent_3]
    env.set_agents(agents)
    agent.load_pretrained()
    UPDATE_TARGET_RATE = 20
    GAME_COUNT_PER_EPISODE = 2
    min_perf, max_perf = 1.0, 0.0
    for episode_cnt in range(1, max_episodes + 1):
        loss = agent.train(
            collect_gameplay_experiences(env, agents, GAME_COUNT_PER_EPISODE))
        avg_reward = evaluate_training_result(env, agent)
        target_update = episode_cnt % UPDATE_TARGET_RATE == 0
        if avg_reward > max_perf:
            max_perf = avg_reward
            agent.save_weight()
        if avg_reward < min_perf:
            min_perf = avg_reward
        print('{0:03d}/{1} perf:{2:.2f} (min:{3} max:{4}) '
              'up:{5:1d} loss:{6}'.format(episode_cnt, max_episodes,
                                          avg_reward, min_perf, max_perf,
                                          target_update, loss))
        if target_update:
            agent.update_target_network()
    # env.close()
    print('training end')
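
The collect_gameplay_experiences helper is only referenced above. A minimal sketch of what it might do, assuming env.run(is_training=True) returns one trajectory per player as in Example 1 (the project's real helper may differ):

def collect_gameplay_experiences(env, agents, game_count):
    # Play game_count self-play games and gather the resulting trajectories;
    # the exact structure that agent.train() expects is an assumption here.
    experiences = []
    for _ in range(game_count):
        experiences.append(env.run(is_training=True))
    return experiences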
Example 4
def train_model(max_episodes=50000):
    """
    Trains a DQN agent to play the CartPole game by trial and error

    :return: None
    """
    agent = DqnAgent()
    buffer = ReplayBuffer()
    env = gym.make('CartPole-v0')
    for _ in range(100):
        collect_gameplay_experiences(env, agent, buffer)
    for episode_cnt in range(max_episodes):
        collect_gameplay_experiences(env, agent, buffer)
        gameplay_experience_batch = buffer.sample_gameplay_batch()
        loss = agent.train(gameplay_experience_batch)
        avg_reward = evaluate_training_result(env, agent)
        print('Episode {0}/{1} and so far the performance is {2} and '
              'loss is {3}'.format(episode_cnt, max_episodes, avg_reward,
                                   loss[0]))
        if episode_cnt % 20 == 0:
            agent.update_target_network()
    env.close()
    print('No bug lol!!!')
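
The ReplayBuffer used above is not shown. A minimal sketch that matches the sample_gameplay_batch() call (the store method name and the transition layout are assumptions, not the project's actual API):

import random
from collections import deque

import numpy as np


class ReplayBuffer:
    def __init__(self, max_size=100000):
        self.experiences = deque(maxlen=max_size)

    def store_gameplay_experience(self, state, next_state, reward, action,
                                  done):
        # keep the newest transitions, dropping the oldest once full
        self.experiences.append((state, next_state, reward, action, done))

    def sample_gameplay_batch(self, batch_size=128):
        # draw a random batch and regroup it column-wise for training
        batch = random.sample(self.experiences,
                              min(batch_size, len(self.experiences)))
        return [np.array(column) for column in zip(*batch)]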
Example 5
def test_model(
    model_location=config.DEFAULT_MODEL_LOCATION,
    gamma=config.DEFAULT_GAMMA,
    verbose=config.DEFAULT_VERBOSITY_OPTION,
    learning_rate=config.DEFAULT_LEARNING_RATE,
    checkpoint_location=config.DEFAULT_CHECKPOINT_LOCATION,
    persist_progress_option=config.DEFAULT_PERSIST_PROGRESS_OPTION,
    render_option=config.DEFAULT_RENDER_OPTION,
    eval_eps=config.DEFAULT_EVAL_EPS,
    pause_time=config.DEFAULT_PAUSE_TIME,
    min_steps=config.DEFAULT_MIN_STEPS,
):
    """
    Test model tests agents

    :param min_steps: the minimum steps per episode for evaluation
    :param pause_time: the time paused for preparing screen recording
    :param eval_eps: the number of episode per evaluation
    :param render_option: how the game play should be rendered
    :param persist_progress_option:
    :param checkpoint_location: (not used in testing)
    :param learning_rate: (not used in testing)
    :param verbose: the verbosity level
    :param gamma: (not used in testing)
    :param model_location: used to load the pre-trained model
    :return: None
    """
    env_name = config.DEFAULT_ENV_NAME
    test_env = gym.make(env_name)
    agent = DqnAgent(state_space=test_env.observation_space.shape[0],
                     action_space=test_env.action_space.n,
                     gamma=gamma,
                     verbose=verbose,
                     lr=learning_rate,
                     checkpoint_location=checkpoint_location,
                     model_location=model_location,
                     persist_progress_option=persist_progress_option,
                     mode='test',
                     epsilon=0)
    avg_reward = utils.play_episodes(env=test_env,
                                     policy=agent.random_policy,
                                     render_option=render_option,
                                     num_eps=eval_eps,
                                     pause_time=pause_time,
                                     min_steps=min_steps)
    test_env.close()
    return avg_reward
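
The utils.play_episodes helper is only referenced above. A simplified sketch using the classic Gym step API (min_steps handling is omitted; the project's real helper may differ):

import time


def play_episodes(env, policy, render_option='none', num_eps=1,
                  pause_time=0.0, min_steps=0):
    # give the user time to start screen recording before play begins
    time.sleep(pause_time)
    total_reward = 0.0
    for _ in range(num_eps):
        state = env.reset()
        done = False
        while not done:
            if render_option != 'none':
                env.render()
            action = policy(state)
            state, reward, done, _ = env.step(action)
            total_reward += reward
    return total_reward / num_eps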
Example 6
def train_model(max_episodes=100):
    """
    Trains a DQN agent to play the Whale game by trial and error

    :return: None
    """

    # buffer = ReplayBuffer()
    # Make environment
    env = WhaleEnv(
        config={
            'active_player': 0,
            'seed': datetime.utcnow().microsecond,
            'env_num': 1,
            'num_players': 5
        })
    # Set up agents
    action_num = 3
    agent = DqnAgent(action_num=action_num, player_num=5)
    agent_0 = NoDrawAgent(action_num=action_num)
    agent_1 = NoDrawAgent(action_num=action_num)
    agent_2 = NoDrawAgent(action_num=action_num)
    agent_3 = NoDrawAgent(action_num=action_num)
    # agent_train = RandomAgent(action_num=action_num)
    agents = [agent, agent_0, agent_1, agent_2, agent_3]
    # train_agents = [agent_train, agent_0, agent_1, agent_2, agent_3]
    env.set_agents(agents)
    agent.load_pretrained()
    min_perf, max_perf = 1.0, 0.0
    for episode_cnt in range(1, max_episodes + 1):
        # print(f'{datetime.utcnow()} train ...')
        loss = agent.train(
            collect_gameplay_experiences(env, agents, GAME_COUNT_PER_EPISODE))
        # print(f'{datetime.utcnow()} eval  ...')
        avg_rewards = evaluate_training_result(env, agents,
                                               EVAL_EPISODES_COUNT)
        # print(f'{datetime.utcnow()} calc  ...')
        if avg_rewards[0] > max_perf:
            max_perf = avg_rewards[0]
            agent.save_weight()
        if avg_rewards[0] < min_perf:
            min_perf = avg_rewards[0]
        print('{0:03d}/{1} perf:{2:.2f} (min:{3:.2f} max:{4:.2f}) '
              'loss:{5:.4f} rewards:{6:.2f} {7:.2f} {8:.2f} {9:.2f}'.format(
                  episode_cnt, max_episodes, avg_rewards[0], min_perf,
                  max_perf, loss[0], avg_rewards[1], avg_rewards[2],
                  avg_rewards[3], avg_rewards[4]))
    # env.close()
    print('training end')
Example 7
 def post(self, id):
     if id not in agents:
         json_data = request.get_json(force=True)
         print("##################################")
         print(json_data)
         if json_data['a_type'] == "int":
             a_type = np.int32
         if json_data['a_type'] in ("float", "double"):
             a_type = np.float64
         if json_data['o_type'] == "int":
             o_type = np.int32
         if json_data['o_type'] in ("float", "double"):
             o_type = np.float64
         if json_data['agent_type'] == "dqn":
             agent = DqnAgent(
                 array_spec.BoundedArraySpec(shape=json_data['a_shape'],
                                             dtype=a_type,
                                             minimum=json_data['a_min'],
                                             maximum=json_data['a_max'],
                                             name='action'),
                 array_spec.BoundedArraySpec(shape=json_data['o_shape'],
                                             dtype=o_type,
                                             minimum=json_data['o_min'],
                                             maximum=json_data['o_max'],
                                             name='observation'),
                 np.array(json_data['init_state'], dtype=o_type),
                 json_data['parameters'])
         if json_data['agent_type'] == "reinforce":
             agent = ReinforceAgent(
                 array_spec.BoundedArraySpec(shape=json_data['a_shape'],
                                             dtype=a_type,
                                             minimum=json_data['a_min'],
                                             maximum=json_data['a_max'],
                                             name='action'),
                 array_spec.BoundedArraySpec(shape=json_data['o_shape'],
                                             dtype=o_type,
                                             minimum=json_data['o_min'],
                                             maximum=json_data['o_max'],
                                             name='observation'),
                 np.array(json_data['init_state'], dtype=o_type),
                 json_data['parameters'])
         agents[id] = agent
Example 8
def train_model(
    num_iterations=config.DEFAULT_NUM_ITERATIONS,
    batch_size=config.DEFAULT_BATCH_SIZE,
    max_replay_history=config.DEFAULT_MAX_REPLAY_HISTORY,
    gamma=config.DEFAULT_GAMMA,
    eval_eps=config.DEFAULT_EVAL_EPS,
    learning_rate=config.DEFAULT_LEARNING_RATE,
    target_network_update_frequency=config.
    DEFAULT_TARGET_NETWORK_UPDATE_FREQUENCY,
    checkpoint_location=config.DEFAULT_CHECKPOINT_LOCATION,
    model_location=config.DEFAULT_MODEL_LOCATION,
    verbose=config.DEFAULT_VERBOSITY_OPTION,
    visualizer_type=config.DEFAULT_VISUALIZER_TYPE,
    render_option=config.DEFAULT_RENDER_OPTION,
    persist_progress_option=config.DEFAULT_PERSIST_PROGRESS_OPTION,
    epsilon=config.DEFAULT_EPSILON,
):
    """
    Trains a DQN agent by playing episodes of the Cart Pole game

    :param epsilon: the probability that a random action is chosen
    :param target_network_update_frequency: how frequently the target Q network gets updated
    :param num_iterations: the number of episodes the agent will play
    :param batch_size: the training batch size
    :param max_replay_history: the maximum length of the replay buffer
    :param gamma: the discount rate
    :param eval_eps: the number of episodes per evaluation
    :param learning_rate: the learning rate for backpropagation
    :param checkpoint_location: the location to save the training checkpoints
    :param model_location: the location to save the pre-trained models
    :param verbose: the verbosity level, which can be progress, loss, policy or init
    :param visualizer_type: the type of visualization to be used
    :param render_option: whether the game play should be rendered
    :param persist_progress_option: whether the training progress should be saved

    :return: (maximum average reward, baseline average reward)
    """
    visualizer = get_training_visualizer(visualizer_type=visualizer_type)
    use_epsilon = epsilon
    if visualizer_type == 'streamlit':
        use_epsilon = visualizer.get_ui_feedback()['epsilon']
    env_name = config.DEFAULT_ENV_NAME
    train_env = gym.make(env_name)
    eval_env = gym.make(env_name)
    agent = DqnAgent(state_space=train_env.observation_space.shape[0],
                     action_space=train_env.action_space.n,
                     gamma=gamma,
                     verbose=verbose,
                     lr=learning_rate,
                     checkpoint_location=checkpoint_location,
                     model_location=model_location,
                     persist_progress_option=persist_progress_option,
                     mode='train',
                     epsilon=use_epsilon)
    benchmark_reward = compute_avg_reward(eval_env, agent.random_policy,
                                          eval_eps)
    buffer = DqnReplayBuffer(max_size=max_replay_history)
    max_avg_reward = 0.0
    for eps_cnt in range(num_iterations):
        collect_episode(train_env, agent.collect_policy, buffer, render_option)
        if buffer.can_sample_batch(batch_size):
            state_batch, next_state_batch, action_batch, reward_batch, done_batch = \
                buffer.sample_batch(batch_size=batch_size)
            loss = agent.train(state_batch=state_batch,
                               next_state_batch=next_state_batch,
                               action_batch=action_batch,
                               reward_batch=reward_batch,
                               done_batch=done_batch,
                               batch_size=batch_size)
            visualizer.log_loss(loss=loss)
            use_eval_eps = eval_eps
            if visualizer_type == 'streamlit':
                use_eval_eps = visualizer.get_ui_feedback()['eval_eps']
            avg_reward = compute_avg_reward(eval_env,
                                            agent.policy,
                                            num_episodes=use_eval_eps)
            visualizer.log_reward(reward=[avg_reward])
            if avg_reward > max_avg_reward:
                max_avg_reward = avg_reward
                if persist_progress_option == 'all':
                    agent.save_model()
            if verbose != 'none':
                print(
                    'Episode {0}/{1}({2}%) finished with avg reward {3} w/ benchmark reward {4}'
                    ' and buffer volume {5}'.format(
                        eps_cnt, num_iterations,
                        round(eps_cnt / num_iterations * 100.0, 2), avg_reward,
                        benchmark_reward, buffer.get_volume()))
        else:
            if verbose != 'none':
                print('Not enough sample, skipping...')
        used_target_network_update_frequency = target_network_update_frequency
        if visualizer_type == 'streamlit':
            used_target_network_update_frequency = (
                visualizer.get_ui_feedback()['update_freq'])
        if eps_cnt % used_target_network_update_frequency == 0:
            agent.update_target_network()
    train_env.close()
    eval_env.close()
    return max_avg_reward, benchmark_reward
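
compute_avg_reward is only called above. A minimal sketch, assuming it averages undiscounted episode returns for the given policy under the classic Gym step API (the project's own implementation may differ):

def compute_avg_reward(env, policy, num_episodes):
    # play num_episodes full episodes with the given policy and
    # return the mean undiscounted return
    total_return = 0.0
    for _ in range(num_episodes):
        state = env.reset()
        done = False
        while not done:
            action = policy(state)
            state, reward, done, _ = env.step(action)
            total_return += reward
    return total_return / num_episodes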
Example 9
class Player:
    def __init__(self, name, type_):
        self.score = 0
        self.name = name
        self.hand = None
        self.mask = None
        self.type = type_

        if self.type == PlayerType.Computer:
            self.agent = DqnAgent(train=False)

    def move(self, deck, top_card, verbose):
        if self.type == PlayerType.Human:
            self.display_hand()
            print('{}, Please Choose a # from the following options'.format(
                self.name))
            print('\t 1: Draw a card from the deck')
            print('\t 2: Pick up a {} from the pile'.format(top_card.name))
            selection = input('Selection: ')
            if selection == '1':
                card = deck.draw()
                print('You drew a {}'.format(card.name))
            else:
                card = top_card
            print('What will you do with this {}?'.format(card.name))
            print(
                'Enter a Number (1-6) to replace that card, or 0 to put it back on the pile'
            )
            selection = input('Selection: ')
            if selection == '0':
                return card
            else:
                position = int(selection) - 1
                old_card = self.hand[position]
                self.hand[position] = card
                self.mask[position] = 1
                return old_card
        else:

            # Pick up the card from the pile if it would zero out a column
            # or its value is <= 6
            draw = True
            for column in range(3):
                card1 = self.hand[column]
                card2 = self.hand[3 + column]
                if card1 != card2 and card1 != Card.Joker and card2 != Card.Joker and (
                        top_card == card1 or top_card == card2):
                    draw = False
            if draw and top_card.value[1] <= 6:
                draw = False

            if draw:
                card = deck.draw()
                if verbose:
                    print('\t{} drew a {} from the deck.'.format(
                        self.name, card.name))
            else:
                card = top_card
                if verbose:
                    print('\t{} picked up a {} from the pile.'.format(
                        self.name, card.name))

            # Now decide what to do with that card
            state = self._get_dqn_state(top_card)
            action = self.agent.policy(state)

            if action == 0:
                if verbose:
                    print('\t{} returned the {} to the pile.'.format(
                        self.name, card.name))
                    self.display_hand()
                return card
            else:
                old_card = self.hand[action - 1]
                self.hand[action - 1] = card
                if verbose:
                    print('\t{} replaced the {}{} at position {} with the {}.'.
                          format(self.name, old_card.name,
                                 '' if self.mask[action - 1] else '(hidden)',
                                 action, card.name))
                    self.display_hand()
                self.mask[action - 1] = 1
                return old_card

    def _get_dqn_state(self, top_card):
        state = [top_card.value[1]]
        for i, card in enumerate(self.hand):
            if self.mask[i]:
                state.append(card.value[1])
            else:
                state.append(-1)
        return np.array(state)

    def flip_cards(self):
        if self.type == PlayerType.Human:
            selection = input(
                '{}, Please Select Which Cards to flip, left-to-right beginning at 1 (i.e. \'1, 4\'): '
                .format(self.name)).replace(' ', '').split(',')
            self.mask[int(selection[0]) - 1] = 1
            self.mask[int(selection[1]) - 1] = 1
        else:
            self.mask[0] = 1
            self.mask[1] = 1

    def new_hand(self, cards):
        self.hand = cards
        self.mask = [0 for _ in range(len(self.hand))]

    def is_done(self):
        return all(self.mask)

    def get_hand_score(self):
        hand_score = 0
        if not self.hand:
            return None
        for column in range(3):
            card1 = self.hand[column]
            card2 = self.hand[3 + column]
            if card1 != card2 and card1 != Card.Joker and card2 != Card.Joker:
                hand_score += card1.value[0] + card2.value[0]
        return hand_score

    def display_hand(self):
        s = ''
        for i in range(len(self.hand)):
            if self.mask[i]:
                s += self.hand[i].name
                for _ in range(6 - len(self.hand[i].name)):
                    s += ' '
            else:
                s += '  ?   '
            if i == 2:
                print(s + '\n')
                s = ''
        print(s + '\n')
Example 10
fc_param_1 = {"units": [320, 160, 64, 32, 20],
              "bias": [True],
              "activation": ['ReLU', 'ReLU', 'ReLU', 'ReLU', 'ReLU']}

fc_param_2 = {"units": [40, 20], "bias": [False], "activation": ['ReLU']}

network = DqnNet(12, 10,
                 conv_network_param=None,
                 fc_network_param=fc_param_2)

loss_fn = nn.functional.smooth_l1_loss

agent = DqnAgent(environment=env,
                 preprocessor=process_from_replay_sample,
                 network=network,
                 batch_size=batch_size,
                 loss_fn=loss_fn,
                 optimizer='Adam',
                 discount=0.9,
                 update_tau=0.5,
                 update_period=30,
                 learning_rate=1e-3,
                 eps_greedy=0.98,
                 eps_decay_count=200000,
                 eps_minimum=0.1)

raw_sample = replay_memory.sample(batch_size)

#agent.train_step(sample)
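
A hypothetical completion of the commented-out call above, assuming train_step accepts the raw replay sample and applies the preprocessor passed to the agent internally (both points are assumptions):

loss = agent.train_step(raw_sample)  # assumed signature, mirroring the commented-out call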
Example 11
ACTION_INTERVAL = 4
INITIAL_ACTION_SKIPS = 30
TOTAL_ACTION_NUM = 1000000

SAVE_PATH = "./breakout_ckpt/"

env = gym.make('Breakout-v0')
observation = env.reset()
frame_preprocessor = FramePreProcessor(RESIZE_WIDTH, RESIZE_HEIGHT, FRAME_STACK_NUM)

session = tf.Session()
state_space = len(observation)
action_space = env.action_space.n
model = AtariNetwork(session, RESIZE_WIDTH, RESIZE_HEIGHT, FRAME_STACK_NUM, action_space)
target = AtariNetwork(session, RESIZE_WIDTH, RESIZE_HEIGHT, FRAME_STACK_NUM, action_space)
agent = DqnAgent(session, action_space, model, target, ReplayMemory(REPLAY_MEMORY_SIZE, BATCH_SIZE),
                 TRAIN_INTERVAL, TARGET_UPDATE_INTERVAL, ACTION_INTERVAL, TOTAL_ACTION_NUM)

TOTAL_EPISODES = 50000
saver = tf.train.Saver(agent.get_tf_variables())

tick = 0
with open(SAVE_PATH + 'log.txt', 'w') as log:
    for ep in range(TOTAL_EPISODES):
        if ep % 1000 == 0:
            saver.save(session, save_path=SAVE_PATH + str(ep))

        observation = env.reset()
        for _ in range(random.randint(1, INITIAL_ACTION_SKIPS)):
            prev_observation = observation
            observation, _, _, _ = env.step(0)
Example 12
LEARNING_RATE = 0.001
DISCOUNT_RATE = 0.99

env = gym.make('Acrobot-v1')
env.reset()
total_episode = 3000

observation = env.reset()
session = tf.Session()
state_space = len(observation)
action_space = env.action_space.n
model = AcrobatNn(session, state_space, action_space, LEARNING_RATE,
                  DISCOUNT_RATE)
target = AcrobatNn(session, state_space, action_space, LEARNING_RATE,
                   DISCOUNT_RATE)
agent = DqnAgent(session, action_space, model, target, ReplayMemory())
step = 0
saver = tf.train.Saver(agent.get_tf_variables())
#saver.restore(session, save_path='acrobat_ckpt/150')
for ep in range(total_episode):
    if ep % 50 == 0:
        saver.save(session, save_path="./acrobat_ckpt/" + str(ep))

    observation = env.reset()
    agent.begin_episode(observation)
    for t in range(21000):
        step += 1
        action = agent.get_action()
        observation, reward, done, info = env.step(action)
        if ep % 50 == 0 and t < 300:
            #env.render()
Example 13
args.start_timesteps = 10
args.buffer_size = 5000
args.num_steps = 2000000
args.experiment = 's256x128_bs1024_adam2e6_per_normal_nonstep'
args.load = True
args.train = True

env = gym.make(args.environment)
args.action_dim = env.action_space.n
args.state_dim = env.observation_space.shape[0]

dqn = DqnAgent(args.state_dim,
               args.action_dim,
               args.buffer_size,
               args.batch_size,
               args.experiment,
               256,
               128,
               priorized_exp=True,
               gn=False)

base_dir = (os.getcwd() + '/models/' + args.environment + '_' +
            args.experiment + '/')
run_number = 0
while os.path.exists(base_dir + str(run_number)):
    run_number += 1

if args.load: dqn.restore_training(base_dir + 'training/')

os.makedirs(base_dir + str(run_number))
Example 14
import numpy as np
from tf_agents.specs import array_spec

from generic_environment import GenericEnv
from dqn_agent import DqnAgent

#params
num_episode = 2000  # @param
board_size = 9

#env
dqn = DqnAgent(
    array_spec.BoundedArraySpec(shape=(),
                                dtype=np.int32,
                                minimum=0,
                                maximum=3,
                                name='action'),
    array_spec.BoundedArraySpec(shape=(2, ),
                                dtype=np.int32,
                                minimum=0,
                                maximum=board_size,
                                name='observation'),
    np.array([0, 0], dtype=np.int32))

#[row, column]
state = np.array([0, 0], dtype=np.int32)

episode_count = 0
step_count = 0
while episode_count < num_episode:
    if episode_count < 1000:
        action_step = dqn.get_train_action()
    else:
Example 15
 def __init__(self, environment, memory, action_size):
     self._environment = environment
     self._memory = memory
     batchHelper = BatchHelper(memory, DqnGlobals.BATCH_SIZE, action_size)
     self._agent = DqnAgent(action_size, batchHelper)
     self._gifSaver = GifSaver(memory, self._agent)
Example 16
class EpisodeManager:
    def __init__(self, environment, memory, action_size):
        self._environment = environment
        self._memory = memory
        batchHelper = BatchHelper(memory, DqnGlobals.BATCH_SIZE, action_size)
        self._agent = DqnAgent(action_size, batchHelper)
        self._gifSaver = GifSaver(memory, self._agent)

    def ShouldStop(self):
        return os.path.isfile("StopTraining.txt")

    def Run(self):
        scoreHistory = []
        while not self.ShouldStop():
            startTime = time.time()
            score, steps = self.RunOneEpisode()
            elapsedTime = time.time() - startTime
            self._agent.OnGameOver(steps)
            self._gifSaver.OnEpisodeOver()
            totalSteps = self._agent.total_step_count
            episode = self._agent.total_episodes
            scoreHistory.append(score)
            print(
                f"Episode: {episode};  Score: {score};  Steps: {steps}; Time: {elapsedTime:.2f}"
            )
            if (len(scoreHistory) == 10):
                avgScore = np.mean(scoreHistory)
                print(
                    f"Episode: {episode};  Average Score: {avgScore};  Total Steps:  {totalSteps}"
                )
                scoreHistory.clear()
        self._agent.OnExit()

    def OnNextEpisode(self):
        self._environment.reset()
        info = None
        for _ in range(
                np.random.randint(DqnGlobals.FRAMES_PER_STATE,
                                  DqnGlobals.MAX_NOOP)):
            frame, _, done, info = self.NextStep(self._agent.GetFireAction())
            self._memory.AddFrame(frame)
        return info

    def NextStep(self, action):
        rawFrame, reward, done, info = self._environment.step(action)
        processedFrame = ImagePreProcessor.Preprocess(rawFrame)
        return processedFrame, reward, done, info

    def RunOneEpisode(self):
        info = self.OnNextEpisode()
        done = False
        stepCount = 0
        score = 0
        livesLeft = info['ale.lives']
        while not done:
            action = self._agent.GetAction()
            frame, reward, done, info = self.NextStep(action)
            score += reward
            if (info['ale.lives'] < livesLeft):
                reward = -1
                livesLeft = info['ale.lives']
            self._memory.AddMemory(frame, action, reward, done)
            self._agent.Replay()
            stepCount += 1
        return score, stepCount
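
A hypothetical usage sketch for this class; the environment id, the replay-memory construction, and the action-size lookup are assumptions, not taken from the original source:

import gym

env = gym.make('BreakoutDeterministic-v4')   # assumed Atari environment
memory = ReplayMemory()                      # assumed replay-memory implementation
manager = EpisodeManager(env, memory, action_size=env.action_space.n)
manager.Run()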