from snake_env import *
from Actor import Actor
from Critic import Critic  # assumed to sit in Critic.py, mirroring Actor.py
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tqdm import tqdm

tf.keras.backend.set_floatx('float64')
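
# Advantage Actor-Critic (A2C) on snake_env: play() rolls out a trained
# actor for inspection; Agent.train() runs the actor-critic update loop.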

def play(weights_path='./weights/baseline.h5'):
    '''
    roll out a trained actor in snake_env forever, rendering each step
    '''
    env = snake_env()
    actor = Actor((env.size, env.size), env.action_space, 0)
    actor.model.load_weights(weights_path)
    while True:
        state = env.reset()
        done = False
        while not done:
            env.render()
            delay(.1)
            # sample an action from the actor's softmax policy
            probs = tf.nn.softmax(
                actor.model(np.expand_dims(state, 0))[0]).numpy()
            action = np.random.choice(env.action_space, p=probs)

            next_state, reward, done, info = env.step(action)
            state = next_state
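
# A2C agent: the Actor proposes a softmax policy over actions and the Critic
# estimates state values used as the baseline for the policy gradient.
# Hyperparameters are read from a module-level `args` namespace (see the
# argparse sketch at the bottom of this file).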
class Agent:
    def __init__(self):
        self.env = snake_env()
        self.state_dim = (self.env.size, self.env.size)
        self.action_dim = self.env.action_space
        self.actor = Actor(self.state_dim, self.action_dim, args.actor_lr)
        self.critic = Critic(self.state_dim, args.critic_lr)
        self.gamma = args.gamma

        if args.load_weights:
            self.actor.model.load_weights(args.load_weights)

        if args.dist_move_reward:
            self.env.set_reward(move_reward='-dist')

        # initialize video system only
        self.env.reset()

        # self.env.render()

    def MC(self, rewards, dones, next_value):
        '''
        Monte Carlo-style return estimation: discounted rewards accumulated
        backwards through the rollout, bootstrapped with the critic's
        next_value at the rollout's end and cut off at episode boundaries
        '''
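        # e.g. gamma=0.9, rewards=[1, 0, 1], dones=[0, 0, 1], next_value=5.0:
        #   t=2: 1                     (episode done, bootstrap dropped)
        #   t=1: 0 + 0.9 * 1   = 0.9
        #   t=0: 1 + 0.9 * 0.9 = 1.81
        # -> returned as the column vector [[1.81], [0.9], [1.0]]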
        rewards = rewards.reshape(-1)
        returns = np.append(np.zeros_like(rewards), next_value, axis=-1)
        for t in reversed(range(rewards.shape[0])):
            returns[t] = rewards[t] + self.gamma * returns[t + 1] * (1 -
                                                                     dones[t])

        return returns[:-1].reshape(-1, 1)

    def advantage(self, returns, baselines):
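        # how much better the sampled returns were than the critic's estimate;
        # subtracting the learned baseline reduces policy-gradient variance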
        return returns - baselines

    def list_to_batch(self, _list):
        '''
        stack a list of single-sample batches into one batch of len(_list)
        '''
        return np.concatenate(_list, axis=0)

    def train(self, max_updates=100, batch_size=64):
        episode_reward_list = []
        episode_length_list = []
        snake_length_list = []
        actor_loss = 0
        critic_loss = 0
        for up in tqdm(range(max_updates)):
            state_list = []
            action_list = []
            reward_list = []
            done_list = []
            step_reward_list = []
            step_snake_length = []

            state = self.env.reset()
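            # collect a fixed-length rollout of batch_size steps; episodes may
            # terminate and restart several times inside one rollout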

            for ba in range(batch_size):
                # self.env.render()

                # data collection: sample an action from the actor's
                # softmax policy
                probs = tf.nn.softmax(
                    self.actor.model(np.expand_dims(state, 0))[0]).numpy()
                action = np.random.choice(self.action_dim, p=probs)

                next_state, reward, done, info = self.env.step(action)
                step_reward_list.append(reward)
                step_snake_length.append(info['length'])

                if done:
                    # the end of an episode
                    episode_length_list.append(len(step_reward_list))
                    episode_reward_list.append(
                        sum(step_reward_list) / len(step_reward_list))
                    snake_length_list.append(
                        sum(step_snake_length) / len(step_snake_length))

                    n_episode = len(episode_reward_list)
                    if n_episode % args.log_interval == 0:
                        print(
                            f'\nEpisode: {n_episode}, Avg Reward: {episode_reward_list[-1]}'
                        )

                    step_reward_list = []
                    step_snake_length = []
                    next_state = self.env.reset()

                    # checkpoint the actor whenever a new best mean episode
                    # reward is reached
                    if max(episode_reward_list) == episode_reward_list[-1]:
                        self.actor.model.save_weights(args.save_weights)

                # make single batches
                state = np.expand_dims(state, 0)
                action = np.expand_dims(action, (0, 1))
                reward = np.expand_dims(reward, (0, 1))
                done = np.expand_dims(done, (0, 1))

                state_list.append(state)
                action_list.append(action)
                reward_list.append(reward)
                done_list.append(done)

                state = next_state

            # update the batch at once
            # convert list of batches into a batch of len(list)
            states = self.list_to_batch(state_list)
            actions = self.list_to_batch(action_list)
            rewards = self.list_to_batch(reward_list)
            dones = self.list_to_batch(done_list)

            # bootstrap value for the state that follows the last step of the
            # rollout (`state` already holds that next_state here)
            next_value = self.critic.model(np.expand_dims(state, 0))[0]

            returns = self.MC(rewards, dones, next_value)

            advantages = self.advantage(returns,
                                        self.critic.model.predict(states))

            actor_loss = self.actor.train(states, actions, advantages)
            critic_loss = self.critic.train(states, returns)
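            # presumably actor.train applies a policy-gradient step weighted by
            # the advantages, while critic.train regresses the value estimates
            # toward the returns; together this is the standard A2C update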

        # summarize the logs (mean over non-overlapping windows of mean_n
        # episodes) and save the training figure
        mean_n = 100
        n_episode = len(episode_reward_list)

        def window_mean(values):
            # divide by the actual window length so a shorter final window
            # is not underestimated
            return [
                sum(values[l:l + mean_n]) / len(values[l:l + mean_n])
                for l in range(0, n_episode, mean_n)
            ]

        episode_reward_list = window_mean(episode_reward_list)
        episode_length_list = window_mean(episode_length_list)
        snake_length_list = window_mean(snake_length_list)

        x = np.linspace(0, n_episode, len(episode_reward_list))

        plt.plot(x, episode_reward_list, label='Mean 100-Episode Reward')
        plt.plot(x, snake_length_list, label='Mean 100-Episode Snake Length')
        plt.plot(x,
                 episode_length_list,
                 label='Mean 100-Episode Episode Length')
        plt.legend()
        plt.xlabel('Episode')
        plt.title('A2C-snake_env')
        plt.savefig(args.save_figure)
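

if __name__ == '__main__':
    # Minimal entry-point sketch (an assumption, not part of the original
    # training setup): the flag names mirror the `args.*` attributes used
    # above, and the default values are placeholders rather than tuned ones.
    import argparse

    parser = argparse.ArgumentParser(description='A2C on snake_env')
    parser.add_argument('--actor_lr', type=float, default=1e-4)
    parser.add_argument('--critic_lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--load_weights', type=str, default='')
    parser.add_argument('--dist_move_reward', action='store_true')
    parser.add_argument('--log_interval', type=int, default=100)
    parser.add_argument('--save_weights', type=str,
                        default='./weights/a2c_actor.h5')
    parser.add_argument('--save_figure', type=str, default='./a2c_snake.png')
    args = parser.parse_args()

    agent = Agent()
    agent.train(max_updates=100, batch_size=64)
    # play()  # uncomment to watch a saved policy instead of training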