Example #1
    def prepare_batch(self, batch):
        """
        Transposes and pre-processes a batch of transitions into batches of torch tensors.
            batch: list of transitions [[s, a, r, s2, done],
                                        [s, a, r, s2, done]]

        Returns: [s], [a], [r], [s2], [done_mask]
        """
        states, actions, rewards, next_states, done_mask = [], [], [], [], []
        for state, action, reward, next_state, done in batch:
            states.append(process_state(state))
            actions.append(action)
            rewards.append(reward)
            next_states.append(process_state(next_state))
            done_mask.append(1 - done)  # turn True values into zero for mask
        states = torch.cat(states)
        next_states = torch.cat(next_states)
        rewards = torch.FloatTensor(rewards)
        done_mask = torch.FloatTensor(done_mask)
        return states, actions, rewards, next_states, done_mask
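All of the examples on this page revolve around a project-specific process_state helper that is not shown, and its exact signature varies across the snippets below. A minimal sketch of the single-frame, torch-based variant assumed by the DQN-style examples:

import numpy as np
import torch

def process_state(frame):
    # Hypothetical sketch: turn a raw HxWxC (or HxW) uint8 frame into a
    # normalized 1xCxHxW float tensor. Real implementations typically also
    # resize, convert to grayscale and stack several frames.
    tensor = torch.from_numpy(np.ascontiguousarray(frame, dtype=np.float32)) / 255.0
    if tensor.dim() == 2:          # grayscale frame: add a channel dimension
        tensor = tensor.unsqueeze(0)
    else:                          # HWC -> CHW
        tensor = tensor.permute(2, 0, 1)
    return tensor.unsqueeze(0)     # add a batch dimension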
Example #2
    def select_action(self, state, epsilon):
        """
        Epsilon-greedy policy: selects the action with the maximum predicted
        Q value, otherwise selects a random action with probability epsilon.
        Args:
            state: current state of the environment (stack of 4 image frames)
            epsilon: probability of random action (1.0 - 0.0)

        Returns: action
        """
        if epsilon > random.random():
            return self.environment.action_space.sample()
        with torch.no_grad():  # inference only; replaces the deprecated Variable(..., volatile=True)
            state = process_state(state).cuda()
            return int(self.q_network(state).max(1)[1])
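A hedged usage sketch for select_action: the caller usually anneals epsilon over the course of training and passes the current value in (the constants and the agent name below are illustrative, not taken from the source):

EPS_START, EPS_END, EPS_DECAY_STEPS = 1.0, 0.1, 100_000

def epsilon_at(step):
    # Linear annealing from EPS_START down to EPS_END, then held constant.
    fraction = min(step / EPS_DECAY_STEPS, 1.0)
    return EPS_START + fraction * (EPS_END - EPS_START)

# action = agent.select_action(state, epsilon_at(steps_done))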
Example #3
def get_fixed_states():

    fixed_states = []

    env = gym.make('BreakoutNoFrameskip-v0')

    cumulative_screenshot = []

    def prepare_cumulative_screenshot(cumul_screenshot):
        # Prepare the cumulative screenshot
        padding_image = torch.zeros((1, constants.STATE_IMG_HEIGHT, constants.STATE_IMG_WIDTH))
        for i in range(constants.N_IMAGES_PER_STATE - 1):
            cumul_screenshot.append(padding_image)

        screen_grayscale_state = get_screen(env)
        cumul_screenshot.append(screen_grayscale_state)

    prepare_cumulative_screenshot(cumulative_screenshot)
    env.reset()

    for steps in range(constants.N_STEPS_FIXED_STATES + 8):
        if constants.SHOW_SCREEN:
            env.render()

        _, _, done, _ = env.step(env.action_space.sample())  # take a random action

        if done:
            env.reset()
            cumulative_screenshot = []
            prepare_cumulative_screenshot(cumulative_screenshot)

        screen_grayscale = get_screen(env)
        cumulative_screenshot.append(screen_grayscale)
        cumulative_screenshot.pop(0)
        state = utils.process_state(cumulative_screenshot)

        if steps >= 8:
            fixed_states.append(state)

    env.close()
    return fixed_states
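The fixed states collected here serve as a stable evaluation probe: Example #7 below averages the maximum predicted Q-value over them once per epoch. A minimal usage sketch (q_network stands in for whichever network is being evaluated):

import torch

fixed_states = get_fixed_states()
with torch.no_grad():
    avg_max_q = sum(q_network(s).max(1)[0].item() for s in fixed_states) / len(fixed_states)
print("Average max Q over fixed states:", avg_max_q)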
Example #4
    def learn(self):
        for actor in self.actors:
            actor.train()
        self.critic.train()

        reward_per_episode = []
        reward_per_episode_early_stop = []
        step_per_episode = []
        entropy_per_episode = []
        critic_loss_per_episode = []
        for episode in range(1, self.n_episodes + 1):
            early_stop = False
            # rollout
            state = self.env.reset()
            state_per_step = []
            log_prob_per_step = [[] for _ in range(self.n_players)]
            entropy_per_step = [[] for _ in range(self.n_players)]
            reward_per_step = []
            for step in range(1, self.episode_max_length + 1):
                state = process_state(state)
                state = torch.tensor([state], device=self.device)
                state_per_step.append(state)
                actions_per_player = []
                for idx, actor in enumerate(self.actors):
                    dist = actor(state)
                    action = dist.sample()  # tensor(14)
                    actions_per_player.append(
                        self.action_list[action[0].item()])
                    log_prob = dist.log_prob(action)  # tensor([-2.1])
                    entropy = dist.entropy()  # tensor([2.8]) # TODO
                    log_prob_per_step[idx].append(log_prob)
                    entropy_per_step[idx].append(entropy)
                # for action in actions_per_player:
                #     print(football_action_set.named_action_from_action_set(self.env.unwrapped._env._action_set, action))
                obs, rew, done, info = self.env.step(actions_per_player)
                rew = rew[0]
                reward_per_step.append(torch.tensor([rew], device=self.device))
                state = obs
                # If possession switches to the opponent, stop the episode early
                if self.args.early_stop:
                    if state[0, 96] == 1:
                        rew = -1
                        early_stop = True
                        break
                if done:
                    break
            reward_per_episode_early_stop.append(rew)
            if early_stop:
                reward_per_episode.append(0)
            else:
                reward_per_episode.append(rew)

            step_per_episode.append(step)

            # update
            state_per_step = torch.cat(state_per_step)
            log_prob_per_step = [
                torch.cat(log_prob_per_player)
                for log_prob_per_player in log_prob_per_step
            ]
            entropy_per_step = [
                torch.cat(entropy_per_player)
                for entropy_per_player in entropy_per_step
            ]
            entropy_per_episode.append([
                entropy_per_player.mean().item()
                for entropy_per_player in entropy_per_step
            ])

            returns = self.calculate_returns(last_state_value=0,
                                             rewards=reward_per_step)
            values = self.critic(state_per_step).squeeze(-1)

            advantages = returns - values
            critic_loss = F.mse_loss(input=values,
                                     target=returns.detach(),
                                     reduction="mean")
            critic_loss_per_episode.append(critic_loss.item())
            critic_loss.backward()
            self.critic_optimizer.step()
            self.critic_optimizer.zero_grad()

            for idx, actor in enumerate(self.actors):
                actor_loss = -log_prob_per_step[idx] * advantages.detach()
                actor_loss = actor_loss.mean()
                actor_loss.backward()
                self.actor_optimizers[idx].step()
                self.actor_optimizers[idx].zero_grad()

            # log

            moving_average_window_size = len(reward_per_episode[-self.k:])
            moving_average_reward = sum(
                reward_per_episode[-self.k:]) / moving_average_window_size
            moving_average_reward_early_stop = sum(
                reward_per_episode_early_stop[-self.k:]
            ) / moving_average_window_size

            moving_average_step = sum(
                step_per_episode[-self.k:]) / moving_average_window_size

            moving_average_entropy_per_player = [
                sum(e) / moving_average_window_size
                for e in zip(*entropy_per_episode[-self.k:])
            ]
            moving_average_critic_loss = sum(
                critic_loss_per_episode[-self.k:]) / moving_average_window_size
            print(
                f"episode: {episode}, reward: {reward_per_episode[-1]}, step: {step_per_episode[-1]}, moving average reward: {moving_average_reward}, early stop: {moving_average_reward_early_stop}, moving average step: {moving_average_step}, moving average critic loss: {moving_average_critic_loss}, moving average entropy: {moving_average_entropy_per_player}"
            )

            self.writer.add_scalar("moving average reward",
                                   moving_average_reward, episode)
            self.writer.add_scalar("moving average reward early stop",
                                   moving_average_reward_early_stop, episode)

            self.writer.add_scalar("moving average step", moving_average_step,
                                   episode)
            self.writer.add_scalar("moving average critic loss",
                                   moving_average_critic_loss, episode)
            for idx in range(self.n_players):
                self.writer.add_scalar(f"moving average entropy of {idx}",
                                       moving_average_entropy_per_player[idx],
                                       episode)
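Examples #4 and #6 call a calculate_returns helper that is not included. A minimal sketch of discounted Monte-Carlo returns consistent with those call sites (the discount factor is an assumption; rewards arrive as 1-element tensors, as built above):

import torch

def calculate_returns(last_state_value, rewards, gamma=0.99):
    # Hypothetical sketch: discounted returns, accumulated backwards from the
    # final step of the rollout.
    returns = []
    running = last_state_value
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.insert(0, running)
    return torch.cat(returns)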
Example #5
def train_dqn(settings):
    required_settings = [
        "batch_size",
        "checkpoint_frequency",
        "device",
        "eps_start",
        "eps_end",
        "eps_cliff",
        "eps_decay",
        "gamma",
        "log_freq",
        "logs_dir",
        "lr",
        "max_steps",
        "memory_size",
        "model_name",
        "num_episodes",
        "out_dir",
        "target_net_update_freq",
    ]
    if not settings_is_valid(settings, required_settings):
        raise Exception(
            f"Settings object {settings} missing some required settings.")

    batch_size = settings["batch_size"]
    checkpoint_frequency = settings["checkpoint_frequency"]
    device = settings["device"]
    eps_start = settings["eps_start"]
    eps_end = settings["eps_end"]
    eps_cliff = settings["eps_cliff"]
    # eps_decay = settings["eps_decay"]
    gamma = settings["gamma"]
    logs_dir = settings["logs_dir"]
    log_freq = settings["log_freq"]
    lr = settings["lr"]
    max_steps = settings["max_steps"]
    memory_size = settings["memory_size"]
    model_name = settings["model_name"]
    num_episodes = settings["num_episodes"]
    out_dir = settings["out_dir"]
    target_net_update_freq = settings["target_net_update_freq"]

    # Initialize environment
    env = gym.make("StarGunner-v0")

    # Initialize model
    num_actions = env.action_space.n
    settings["num_actions"] = num_actions
    policy_net = DQN(settings).to(device)
    target_net = DQN(settings).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    # Initialize memory
    logging.info("Initializing memory.")
    memory = ReplayMemory(memory_size)
    memory.init_with_random((1, 3, 84, 84), num_actions)
    logging.info("Finished initializing memory.")

    # Initialize other model ingredients
    optimizer = optim.Adam(policy_net.parameters(), lr=lr)

    # Initialize tensorboard
    writer = SummaryWriter(logs_dir)

    # Loop over episodes
    policy_net.train()
    steps_done = 0
    log_reward_acc = 0.0
    log_steps_acc = 0
    for episode in tqdm(range(num_episodes)):
        state = process_state(env.reset()).to(device)
        reward_acc = 0.0
        loss_acc = 0.0

        # Loop over steps in episode
        for t in range(max_steps):
            with torch.no_grad():
                Q = policy_net.forward(state.type(torch.float))

            # Get best predicted action and perform it
            if steps_done < eps_cliff:
                # Linear decay from eps_start to eps_end over the first eps_cliff steps
                epsilon = eps_start - (eps_start - eps_end) * steps_done / eps_cliff
            else:
                epsilon = eps_end

            if random.random() < epsilon:
                predicted_action = torch.tensor([env.action_space.sample()
                                                 ]).to(device)
            else:
                predicted_action = torch.argmax(Q, dim=1)
            next_state, raw_reward, done, info = env.step(
                predicted_action.item())
            # Note that next state could also be a difference
            next_state = process_state(next_state)
            reward = torch.tensor([clamp_reward(raw_reward)])

            # Save to memory
            memory.push(state.to("cpu"), predicted_action.to("cpu"),
                        next_state, reward)

            # Move to next state
            state = next_state.to(device)

            # Sample from memory
            batch = Transition(*zip(*memory.sample(batch_size)))

            # Mask terminal state (adapted from https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html)
            final_mask = torch.tensor(
                tuple(map(lambda s: s is not None, batch.next_state)),
                device=device,
                dtype=torch.bool,
            )
            # print("FINAL_MASK", final_mask.shape)
            state_batch = torch.cat(batch.state).type(torch.float).to(device)
            next_state_batch = torch.cat(batch.next_state).type(
                torch.float).to(device)
            action_batch = torch.cat(batch.action).to(device)
            reward_batch = torch.cat(batch.reward).to(device)

            # print("STATE_BATCH SHAPE", state_batch.shape)
            # print("STATE_BATCH", state_batch[4, :, 100])
            # print("ACTION_BATCH SHAPE", action_batch.shape)
            # print("ACTION_BATCH", action_batch)
            # print("REWARD_BATCH SHAPE", reward_batch.shape)

            # Compute Q
            # Q_next = torch.zeros((batch_size, num_actions))
            # print("MODEL STATE BATCH SHAPE", model(state_batch).shape)
            Q_actual = policy_net(state_batch).gather(
                1, action_batch.view(action_batch.shape[0], 1))
            Q_next_pred = target_net(next_state_batch)
            Q_max = torch.max(Q_next_pred, dim=1)[0].detach()
            # print("Q_MAX shape", Q_max.shape)
            target = reward_batch + gamma * Q_max * final_mask.to(Q_max.dtype)
            # print("TARGET SIZE", target.shape)

            # Calculate loss
            loss = F.smooth_l1_loss(Q_actual, target.unsqueeze(1))
            optimizer.zero_grad()
            loss.backward()

            # Clamp gradient to avoid gradient explosion
            for param in policy_net.parameters():
                param.grad.data.clamp_(-1, 1)
            optimizer.step()

            # Store stats
            loss_acc += loss.item()
            reward_acc += raw_reward
            steps_done += 1

            if steps_done % target_net_update_freq == 0:
                target_net.load_state_dict(policy_net.state_dict())

            # Exit if in terminal state
            if done:
                logging.debug(
                    f"Episode {episode} finished after {t} timesteps with reward {reward_acc}."
                )
                break

        logging.debug(f"Loss: {loss_acc / t}")

        # Save model checkpoint
        if (episode != 0) and (episode % checkpoint_frequency == 0):
            save_model_checkpoint(
                policy_net,
                optimizer,
                episode,
                loss,
                f"{out_dir}/checkpoints/{model_name}_{episode}",
            )

        # Log to tensorboard
        log_reward_acc += reward_acc
        log_steps_acc += t
        writer.add_scalar("Loss / Timestep", loss_acc / t, episode)
        if episode % log_freq == 0:
            writer.add_scalar("Reward", log_reward_acc / log_freq, episode)
            writer.add_scalar("Reward / Timestep",
                              log_reward_acc / log_steps_acc, episode)
            writer.add_scalar("Duration", log_steps_acc / log_freq, episode)
            writer.add_scalar("Steps", log_reward_acc / log_steps_acc,
                              steps_done)
            log_reward_acc = 0.0
            log_steps_acc = 0

    # Save model
    save_model(policy_net, f"{out_dir}/{model_name}.model")

    # Report final stats
    logging.info(f"Steps Done: {steps_done}")

    env.close()
    return policy_net
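The clamp_reward helper used inside the loop above is defined elsewhere; a likely sketch, assuming the usual Atari-style clipping of raw rewards into [-1, 1]:

def clamp_reward(reward, low=-1.0, high=1.0):
    # Hypothetical sketch: clip raw environment rewards so that large score
    # jumps do not dominate the smooth L1 loss.
    return max(low, min(high, float(reward)))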
Example #6
    def learn(self):
        for actor in self.actors:
            actor.train()
        self.critic.train()

        # collect experience
        print("collecting experience")
        for _ in range(self.args.episodes_before_training):
            # rollout
            state = self.env.reset()
            state_per_step = []
            action_per_step = []
            reward_per_step = []
            for step in range(1, self.episode_max_length + 1):
                state = process_state(state)
                state = torch.tensor([state], device=self.device)
                state_per_step.append(state)
                actions = np.random.randint(0,
                                            len(self.action_list),
                                            size=self.n_players)

                obs, rew, done, info = self.env.step(
                    [self.action_list[action] for action in actions])
                action_per_step.append(
                    torch.tensor(actions, device=self.device))
                rew = rew[0]
                reward_per_step.append(torch.tensor([rew], device=self.device))
                state = obs
                # If possession switches to the opponent, stop the episode early
                if self.args.early_stop:
                    if state[0, 96] == 1:
                        rew = -1
                        break
                if done:
                    break

            # add to memory

            returns = self.calculate_returns(last_state_value=0,
                                             rewards=reward_per_step)
            self.replay_memory.add_experience(
                states=torch.cat(state_per_step),
                actions=torch.cat(action_per_step).view(-1, self.n_players),
                returns=returns,
                reward_of_episode=rew)
        # training
        print("training")
        reward_per_episode = []
        reward_per_episode_without_early_stop = []
        step_per_episode = []
        entropy_per_episode = []
        critic_loss_per_episode = []
        for episode in range(1, self.n_episodes + 1):
            early_stop = False
            # rollout
            state = self.env.reset()
            state_per_step = []
            action_per_step = []
            reward_per_step = []
            for step in range(1, self.episode_max_length + 1):
                state = process_state(state)
                state = torch.tensor([state], device=self.device)
                state_per_step.append(state)
                actions_per_player = []
                for idx, actor in enumerate(self.actors):
                    dist = actor(state)
                    action = dist.sample()  # tensor(14)
                    action_per_step.append(action)
                    actions_per_player.append(
                        self.action_list[action[0].item()])
                # for action in actions_per_player:
                #     print(football_action_set.named_action_from_action_set(self.env.unwrapped._env._action_set, action))
                obs, rew, done, info = self.env.step(actions_per_player)
                rew = rew[0]
                reward_per_step.append(torch.tensor([rew], device=self.device))
                state = obs
                # If possession switches to the opponent, stop the episode early
                if self.args.early_stop:
                    if state[0, 96] == 1:
                        rew = -1
                        early_stop = True
                        break
                if done:
                    break
            reward_per_episode.append(rew)
            if early_stop:
                reward_per_episode_without_early_stop.append(0)
            else:
                reward_per_episode_without_early_stop.append(rew)

            step_per_episode.append(step)

            # add to memory

            returns = self.calculate_returns(last_state_value=0,
                                             rewards=reward_per_step)
            self.replay_memory.add_experience(
                states=torch.cat(state_per_step),
                actions=torch.cat(action_per_step).view(-1, self.n_players),
                returns=returns,
                reward_of_episode=rew)

            # update
            batch_states, batch_actions, batch_returns = self.replay_memory.sample_minibatch(
                self.args.batch_size)

            values = self.critic(batch_states).squeeze(-1)

            advantages = batch_returns - values
            critic_loss = F.mse_loss(input=values,
                                     target=batch_returns.detach(),
                                     reduction="mean")
            critic_loss_per_episode.append(critic_loss.item())
            critic_loss.backward()
            self.critic_optimizer.step()
            self.critic_optimizer.zero_grad()

            entropy_per_player = []
            for idx, actor in enumerate(self.actors):
                dist = actor(batch_states)
                log_prob = dist.log_prob(batch_actions[:, idx])
                entropy = dist.entropy()
                entropy_per_player.append(entropy.mean().item())

                actor_loss = -log_prob * advantages.detach()
                actor_loss = actor_loss.mean()
                actor_loss.backward()
                self.actor_optimizers[idx].step()
                self.actor_optimizers[idx].zero_grad()
            entropy_per_episode.append(entropy_per_player)

            # log

            moving_average_window_size = len(reward_per_episode[-self.k:])
            moving_average_reward = sum(
                reward_per_episode[-self.k:]) / moving_average_window_size
            moving_average_reward_without_early_stop = sum(
                reward_per_episode_without_early_stop[-self.k:]
            ) / moving_average_window_size

            moving_average_step = sum(
                step_per_episode[-self.k:]) / moving_average_window_size

            moving_average_entropy_per_player = [
                sum(e) / moving_average_window_size
                for e in zip(*entropy_per_episode[-self.k:])
            ]
            moving_average_critic_loss = sum(
                critic_loss_per_episode[-self.k:]) / moving_average_window_size
            print(
                f"episode: {episode}, reward: {reward_per_episode[-1]}, step: {step_per_episode[-1]}, moving average reward: {moving_average_reward}, without early stop: {moving_average_reward_without_early_stop}, moving average step: {moving_average_step}, moving average critic loss: {moving_average_critic_loss}, moving average entropy: {moving_average_entropy_per_player}"
            )

            self.writer.add_scalar("moving average reward",
                                   moving_average_reward, episode)
            self.writer.add_scalar("moving average reward without early stop",
                                   moving_average_reward_without_early_stop,
                                   episode)

            self.writer.add_scalar("moving average step", moving_average_step,
                                   episode)
            self.writer.add_scalar("moving average critic loss",
                                   moving_average_critic_loss, episode)
            for idx in range(self.n_players):
                self.writer.add_scalar(f"moving average entropy of {idx}",
                                       moving_average_entropy_per_player[idx],
                                       episode)
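Example #6 pushes whole rollouts through a replay memory whose class is not shown. A minimal sketch of the interface its call sites assume (the class name, episode-level capacity and eviction policy here are illustrative):

import torch

class EpisodeReplayMemory:
    def __init__(self, capacity):
        self.capacity = capacity  # measured in stored episodes in this sketch
        self.states, self.actions, self.returns = [], [], []

    def add_experience(self, states, actions, returns, reward_of_episode):
        # reward_of_episode is kept for interface parity but unused here.
        self.states.append(states)
        self.actions.append(actions)
        self.returns.append(returns)
        while len(self.states) > self.capacity:
            self.states.pop(0)
            self.actions.pop(0)
            self.returns.pop(0)

    def sample_minibatch(self, batch_size):
        states = torch.cat(self.states)
        actions = torch.cat(self.actions)
        returns = torch.cat(self.returns)
        idx = torch.randperm(states.shape[0])[:batch_size]
        return states[idx], actions[idx], returns[idx]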
Example #7
def test_agent(target_nn, fixed_states):
    env = gym.make('BreakoutNoFrameskip-v0')

    steps = 0
    n_episodes = 0
    sum_score = 0
    sum_reward = 0
    sum_score_episode = 0
    sum_reward_episode = 0

    done_last_episode = False

    while steps <= constants.N_TEST_STEPS:
        cumulative_screenshot = []

        sum_score_episode = 0
        sum_reward_episode = 0

        # Prepare the cumulative screenshot
        padding_image = torch.zeros((1, constants.STATE_IMG_HEIGHT, constants.STATE_IMG_WIDTH))
        for i in range(constants.N_IMAGES_PER_STATE - 1):
            cumulative_screenshot.append(padding_image)

        env.reset()

        screen_grayscale_state = get_screen(env)
        cumulative_screenshot.append(screen_grayscale_state)

        state = utils.process_state(cumulative_screenshot)

        prev_state_lives = constants.INITIAL_LIVES

        while steps <= constants.N_TEST_STEPS:
            action = select_action(state, target_nn, env)
            _, reward, done, info = env.step(action)

            sum_score_episode += reward

            reward_tensor = None

            if info["ale.lives"] < prev_state_lives:
                sum_reward_episode += -1
            elif reward < 0:
                sum_reward_episode += -1
            elif reward > 0:
                sum_reward_episode += 1

            prev_state_lives = info["ale.lives"]

            screen_grayscale = get_screen(env)
            cumulative_screenshot.append(screen_grayscale)
            cumulative_screenshot.pop(0)

            if done:
                next_state = None
            else:
                next_state = utils.process_state(cumulative_screenshot)

            if next_state is not None:
                state.copy_(next_state)
            steps += 1
            done_last_episode = done

            if done:
                break

        if done_last_episode:
            sum_score += sum_score_episode
            sum_reward += sum_reward_episode
            n_episodes += 1

    env.close()

    if n_episodes == 0:
        n_episodes = 1
        sum_score = sum_score_episode
        sum_reward = sum_reward_episode

    # Compute Q-values
    sum_q_values = 0
    for state in fixed_states:
        sum_q_values += target_nn(state).max(1)[0]

    return sum_reward / n_episodes, sum_score / n_episodes, n_episodes, sum_q_values.item() / len(fixed_states)
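test_agent relies on a select_action(state, target_nn, env) that is not part of this excerpt; a plausible evaluation-time sketch, greedy with respect to the target network plus a small exploration floor (the 0.05 value is an assumption):

import random
import torch

def select_action(state, target_nn, env, epsilon=0.05):
    # Hypothetical sketch: mostly greedy action selection during evaluation.
    if random.random() < epsilon:
        return env.action_space.sample()
    with torch.no_grad():
        return int(target_nn(state).max(1)[1])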
Example #8
    def _get_experiences(self, difficulty, ob):
        states = []
        poses = []
        actions = []
        rewards = []
        values = []
        subsequences = []

        if not self.render:
            os.environ["SDL_VIDEODRIVER"] = "dummy"
        # ob_num = int(round(np.random.normal(MAX_OB * round(float(self.global_step.eval()) / NUM_GLOBAL_STEPS), 1)))
        INPUT_DICT['ob'] = ob if ob >= 0 else 0
        difficulty = 1 if difficulty < 1 else difficulty  # clamp difficulty to at least 1

        map_info = Env.random_scene(MAP_SIZE,
                                    INPUT_DICT,
                                    difficulty=difficulty)

        subsequence = get_subsequence(self.env, map_info,
                                      self.num_local_steps + SUB_SEQ_DIM)

        state_org = self.env.reset(map_info, STATIC)
        state, s_t_goal, pos = process_state(state_org, need_goal=True)
        state = np.concatenate((state, ) * FRAME, axis=-1)
        state = np.concatenate([state, s_t_goal], axis=-1)
        # record summary
        episode_reward = 0.0
        episode_length = 0
        tstart = time.time()
        terminal = False

        print ""
        for ind in range(self.num_local_steps):
            print "step: " + str(ind) + " ob: " + str(
                INPUT_DICT['ob']) + " difficulty: " + str(difficulty)
            # reward = 0.0
            action, value = self.local_network.sample_action(
                state, pos, subsequence[ind:ind + SUB_SEQ_DIM])
            '''
            if AUX_REWARD:
                try:
                    instruct_action = AStarBlock(state_org)['act_seq'][0]
                except (KeyError, NotImplementedError, RuntimeError):
                    instruct_action = 0
                reward += reward_mapping(action, instruct_action)
            '''
            state_org, reward_tmp, terminal, _ = self.env.step([action])
            # reward += reward_tmp

            # Store this experience.
            states.append(state)
            poses.append(pos)
            actions.append(action)
            # rewards.append(reward)
            rewards.append(reward_tmp)
            values.append(value)
            subsequences.append(subsequence[ind:ind + SUB_SEQ_DIM])

            s_t2, pos = process_state(state_org, need_goal=False)

            state = np.insert(state, -1, s_t2.squeeze(), axis=-1)[:, :, 1:]

            episode_reward += reward_tmp
            episode_length += 1

            if terminal:
                print "Wonderful Path!"
                break

        run_time = time.time() - tstart
        LOGGER.info('Finished episode. Total reward: %d. Length: %d.',
                    episode_reward, episode_length)
        summary = tf.Summary()
        summary.value.add(tag='environment/episode_length',
                          simple_value=episode_length)
        summary.value.add(tag='environment/episode_reward',
                          simple_value=episode_reward)
        summary.value.add(tag='environment/fps',
                          simple_value=episode_length / run_time)

        self.summary_writer.add_summary(summary, self.global_step.eval())
        self.summary_writer.flush()

        # Estimate discounted rewards.
        rewards = np.array(rewards)
        next_value = 0 if terminal else self.local_network.estimate_value(
            state, pos, subsequence[ind:ind + SUB_SEQ_DIM])
        discounted_rewards = _apply_discount(np.append(rewards, next_value),
                                             self.discount)[:-1]

        # Estimate advantages.
        values = np.array(values + [next_value])
        advantages = _apply_discount(
            rewards + self.discount * values[1:] - values[:-1], self.discount)
        return np.array(states), np.array(poses), np.array(actions), advantages,\
               discounted_rewards, sum(rewards), np.array(subsequences)
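The _apply_discount helper used for the discounted rewards and the GAE-style advantages above is not shown; a short sketch of a backwards discounted cumulative sum that matches both call sites:

import numpy as np

def _apply_discount(values, discount):
    # Hypothetical sketch: discounted cumulative sum, computed from the end of
    # the sequence towards the beginning.
    discounted = np.zeros(len(values), dtype=np.float64)
    running = 0.0
    for i in reversed(range(len(values))):
        running = values[i] + discount * running
        discounted[i] = running
    return discounted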
Example #9
def main_training_loop():

    fixed_states = test.get_fixed_states()
    env = gym.make('BreakoutNoFrameskip-v0')

    n_actions = env.action_space.n

    policy_net = DeepQNetwork(constants.STATE_IMG_HEIGHT,
                              constants.STATE_IMG_WIDTH,
                              constants.N_IMAGES_PER_STATE, n_actions)

    target_net = DeepQNetwork(constants.STATE_IMG_HEIGHT,
                              constants.STATE_IMG_WIDTH,
                              constants.N_IMAGES_PER_STATE, n_actions)
    criterion = torch.nn.MSELoss()
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = torch.optim.RMSprop(policy_net.parameters(),
                                    lr=constants.LEARNING_RATE,
                                    momentum=0.95)
    replay_memory = memory.ReplayMemory(constants.REPLAY_MEMORY_SIZE)

    steps_done = 0
    epoch = 0
    information = [[
        "epoch", "n_steps", "avg_reward", "avg_score", "n_episodes",
        "avg_q_value"
    ]]
    try:
        for i_episode in range(constants.N_EPISODES):

            cumulative_screenshot = []

            # Prepare the cumulative screenshot
            padding_image = torch.zeros(
                (1, constants.STATE_IMG_HEIGHT, constants.STATE_IMG_WIDTH))
            for i in range(constants.N_IMAGES_PER_STATE - 1):
                cumulative_screenshot.append(padding_image)

            env.reset()
            episode_score = 0
            episode_reward = 0

            screen_grayscale_state = get_screen(env)
            cumulative_screenshot.append(screen_grayscale_state)

            state = utils.process_state(cumulative_screenshot)

            prev_state_lives = constants.INITIAL_LIVES

            for i in range(constants.N_TIMESTEP_PER_EP):
                if constants.SHOW_SCREEN:
                    env.render()

                action = select_action(state, policy_net, steps_done, env)
                _, reward, done, info = env.step(action)
                episode_score += reward

                reward_tensor = None
                if info["ale.lives"] < prev_state_lives:
                    reward_tensor = torch.tensor([-1])
                    episode_reward += -1
                elif reward > 0:
                    reward_tensor = torch.tensor([1])
                    episode_reward += 1
                elif reward < 0:
                    reward_tensor = torch.tensor([-1])
                    episode_reward += -1
                else:
                    reward_tensor = torch.tensor([0])

                prev_state_lives = info["ale.lives"]

                screen_grayscale = get_screen(env)
                cumulative_screenshot.append(screen_grayscale)
                cumulative_screenshot.pop(0)  # Deletes the first element of the list to save memory space

                if done:
                    next_state = None
                else:
                    next_state = utils.process_state(cumulative_screenshot)

                replay_memory.push(state, action, next_state, reward_tensor)

                if next_state is not None:
                    state.copy_(next_state)

                optimize_model(target_net, policy_net, replay_memory,
                               optimizer, criterion)
                steps_done += 1

                if done:
                    print("Episode:", i_episode, "Steps done:", steps_done,
                          "- Episode reward:", episode_reward,
                          "- Episode score:", episode_score)
                    break

                # Update target policy
                if steps_done % constants.TARGET_UPDATE == 0:
                    target_net.load_state_dict(policy_net.state_dict())

                # Epoch test
                if steps_done % constants.STEPS_PER_EPOCH == 0:
                    epoch += 1
                    epoch_reward_average, epoch_score_average, n_episodes, q_values_average = test.test_agent(
                        target_net, fixed_states)
                    information.append([
                        epoch, steps_done, epoch_reward_average,
                        epoch_score_average, n_episodes, q_values_average
                    ])
                    print("INFO", [
                        epoch, steps_done, epoch_reward_average,
                        epoch_score_average, n_episodes, q_values_average
                    ])

        # Save test information in dataframe
        print("Saving information...")
        information_numpy = numpy.array(information)
        dataframe_information = pandas.DataFrame(columns=information_numpy[0, 0:],
                                                 data=information_numpy[1:, 0:])
        dataframe_information.to_csv("info/results.csv")
        print(dataframe_information)

        # Save target parameters in file
        torch.save(target_net.state_dict(), "info/nn_parameters.txt")

    except KeyboardInterrupt:
        # Save test information in dataframe
        print("Saving information...")
        information_numpy = numpy.array(information)
        dataframe_information = pandas.DataFrame(columns=information_numpy[0, 0:],
                                                 data=information_numpy[1:, 0:])
        dataframe_information.to_csv("info/results.csv")
        print(dataframe_information)

        # Save target parameters in file
        torch.save(target_net.state_dict(), "info/nn_parameters.txt")

    env.close()
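The optimize_model step called in the training loop is defined elsewhere; a hedged sketch of a standard DQN update consistent with how the memory is filled above (it assumes actions and rewards were stored as 1-element tensors, that the memory exposes a sample(batch_size) method as in Example #5, and the batch size and discount below are illustrative):

import torch

def optimize_model(target_net, policy_net, replay_memory, optimizer, criterion,
                   batch_size=32, gamma=0.99):
    # Hypothetical sketch of a standard DQN update over a sampled minibatch.
    transitions = replay_memory.sample(batch_size)
    states, actions, next_states, rewards = zip(*transitions)

    non_final_mask = torch.tensor([s is not None for s in next_states])
    non_final_next = torch.cat([s for s in next_states if s is not None])
    states = torch.cat(states)
    actions = torch.cat(actions).view(-1, 1).long()
    rewards = torch.cat(rewards).float()

    q_values = policy_net(states).gather(1, actions).squeeze(1)
    next_q = torch.zeros(batch_size)
    with torch.no_grad():
        next_q[non_final_mask] = target_net(non_final_next).max(1)[0]
    target = rewards + gamma * next_q

    loss = criterion(q_values, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()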