    def __init__(self, model_1, model_2, num_last_frames=4, memory_size=1000):
        """
        Create a new DQN-based agent.

        Args:
            model_1: a compiled DQN model for snake 1.
            model_2: a compiled DQN model for snake 2.
            num_last_frames (int): the number of last frames the agent will consider.
            memory_size (int): memory size limit for experience replay (-1 for unlimited).
        """
        assert model_1.input_shape[1] == num_last_frames, \
            'Model input shape should be (num_frames, grid_size, grid_size)'
        assert len(model_1.output_shape) == 2, \
            'Model output shape should be (num_samples, num_actions)'
        assert model_2.input_shape[1] == num_last_frames, \
            'Model input shape should be (num_frames, grid_size, grid_size)'
        assert len(model_2.output_shape) == 2, \
            'Model output shape should be (num_samples, num_actions)'

        self.model_1 = model_1
        self.model_2 = model_2
        self.num_last_frames = num_last_frames
        self.memory_1 = ExperienceReplay(
            (num_last_frames, ) + model_1.input_shape[-2:],
            model_1.output_shape[-1] // 3, memory_size)
        self.memory_2 = ExperienceReplay(
            (num_last_frames, ) + model_2.input_shape[-2:],
            model_2.output_shape[-1] // 3, memory_size)
        self.frames = None
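
A minimal Keras sketch (not from this project; make_minimax_dqn and its layer sizes are assumptions) of a model that satisfies the asserts above: the input is (num_frames, grid_size, grid_size) and the output is one Q value per joint action, i.e. num_actions ** 2 values.

from keras.models import Sequential
from keras.layers import Conv2D, Flatten, Dense

def make_minimax_dqn(num_frames=4, grid_size=10, num_actions=3):
    model = Sequential()
    # Frames are stacked along the channel axis, hence channels_first.
    model.add(Conv2D(16, (3, 3), activation='relu', data_format='channels_first',
                     input_shape=(num_frames, grid_size, grid_size)))
    model.add(Conv2D(32, (3, 3), activation='relu', data_format='channels_first'))
    model.add(Flatten())
    model.add(Dense(256, activation='relu'))
    model.add(Dense(num_actions * num_actions))  # joint-action Q matrix, flattened
    model.compile(optimizer='rmsprop', loss='mse')
    return model
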
Example #2
    def __init__(self, model, num_last_frames=4, memory_size=1000, output="."):
        """
        Create a new DQN-based agent.
        
        Args:
            model: a sequence of compiled DQN models (the agent indexes model[0]; a second model is used for double DQN).
            num_last_frames (int): the number of last frames the agent will consider.
            memory_size (int): memory size limit for experience replay (-1 for unlimited). 
            output (str): folder path to output model files.
        """
        assert model[0].input_shape[1] == num_last_frames, \
            'Model input shape should be (num_frames, grid_size, grid_size)'
        assert len(model[0].output_shape) == 2, \
            'Model output shape should be (num_samples, num_actions)'

        self.model = model
        self.num_last_frames = num_last_frames
        self.memory = ExperienceReplay(
            (num_last_frames, ) + model[0].input_shape[-2:],
            model[0].output_shape[-1], memory_size)
        self.frames = None
        self.output = output
        self.num_frames = 0
        self.num_trained_frames = 0
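
A hedged usage note: this constructor indexes model[0], so it expects a sequence of compiled models; the full class appears in Example #8, where a second model enables the 'ddqn' path. A minimal sketch (make_dqn is a hypothetical single-model builder):

agent = DeepQNetworkAgent([make_dqn(), make_dqn()],
                          num_last_frames=4, memory_size=1000,
                          output='checkpoints')
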
Example #3
File: dqn.py  Project: exe1023/VIN-snake
    def __init__(self, model, num_last_frames=4, memory_size=1000, attention=1):
        """
        Create a new DQN-based agent.
        
        Args:
            model: a compiled DQN model.
            num_last_frames (int): the number of last frames the agent will consider.
            memory_size (int): memory size limit for experience replay (-1 for unlimited).
            attention (int): attention mode controlling which inputs are fed to the model (-1 for a plain single-input model).
        """
        #assert model.input_shape[0][1] == num_last_frames, 'Model input shape should be (num_frames, grid_size, grid_size)'
        #assert len(model.output_shape) == 2, 'Model output shape should be (num_samples, num_actions)'

        self.model = model
        self.num_last_frames = num_last_frames
        if attention != -1:
            self.memory = ExperienceReplay((num_last_frames,) + model.input_shape[0][-2:], model.output_shape[-1], memory_size, attention)
        else:
            self.memory = ExperienceReplay((num_last_frames,) + model.input_shape[-2:], model.output_shape[-1], memory_size, attention)
        self.frames = None
        self.attention = attention
Example #4
    def __init__(self,
                 model,
                 env_shape,
                 num_actions,
                 num_last_frames=4,
                 memory_size=1000):
        """
        Create a new DQN-based agent.

        Args:
            model: a DQN model.
            env_shape (int, int): shape of the environment.
            num_actions (int): number of actions.
            num_last_frames (int): the number of last frames the agent will consider.
            memory_size (int): memory size limit for experience replay (-1 for unlimited).
        """
        self.model = model
        self.loss_fn = nn.MSELoss()
        self.optimizer = torch.optim.RMSprop(self.model.parameters(), lr=0.001)

        self.num_last_frames = num_last_frames
        self.memory = ExperienceReplay((num_last_frames, ) + env_shape,
                                       num_actions, memory_size)
        self.frames = None
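
A minimal PyTorch sketch (not from this project; SimpleDQN and its layer sizes are assumptions) of a model compatible with this constructor, mapping a (batch, num_last_frames, height, width) tensor to per-action Q values.

import torch.nn as nn

class SimpleDQN(nn.Module):
    def __init__(self, num_last_frames=4, env_shape=(10, 10), num_actions=3):
        super().__init__()
        # One 3x3 convolution shrinks each spatial dimension by 2.
        conv_out = 16 * (env_shape[0] - 2) * (env_shape[1] - 2)
        self.net = nn.Sequential(
            nn.Conv2d(num_last_frames, 16, kernel_size=3), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(conv_out, 256), nn.ReLU(),
            nn.Linear(256, num_actions),
        )

    def forward(self, x):
        return self.net(x)
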
Example #5
class MinimaxDeepQNetworkAgent(AgentBase):
    """ Represents a two-snake agent powered by a minimax DQN with experience replay. """
    def __init__(self, model_1, model_2, num_last_frames=4, memory_size=1000):
        """
        Create a new DQN-based agent.

        Args:
            model_1: a compiled DQN model for snake 1.
            model_2: a compiled DQN model for snake 2.
            num_last_frames (int): the number of last frames the agent will consider.
            memory_size (int): memory size limit for experience replay (-1 for unlimited).
        """
        assert model_1.input_shape[1] == num_last_frames, \
            'Model input shape should be (num_frames, grid_size, grid_size)'
        assert len(model_1.output_shape) == 2, \
            'Model output shape should be (num_samples, num_actions)'
        assert model_2.input_shape[1] == num_last_frames, \
            'Model input shape should be (num_frames, grid_size, grid_size)'
        assert len(model_2.output_shape) == 2, \
            'Model output shape should be (num_samples, num_actions)'

        self.model_1 = model_1
        self.model_2 = model_2
        self.num_last_frames = num_last_frames
        self.memory_1 = ExperienceReplay(
            (num_last_frames, ) + model_1.input_shape[-2:],
            model_1.output_shape[-1] // 3, memory_size)
        self.memory_2 = ExperienceReplay(
            (num_last_frames, ) + model_2.input_shape[-2:],
            model_2.output_shape[-1] // 3, memory_size)
        self.frames = None

    def begin_episode(self):
        """ Reset the agent for a new episode. """
        self.frames = None

    def get_last_frames(self, observation):
        """
        Get the pixels of the last `num_last_frames` observations, the current frame being the last.

        Args:
            observation: observation at the current timestep.

        Returns:
            Observations for the last `num_last_frames` frames.
        """
        frame = observation
        if self.frames is None:
            self.frames = collections.deque([frame] * self.num_last_frames)
        else:
            self.frames.append(frame)
            self.frames.popleft()
        return np.expand_dims(self.frames, 0).astype(np.float32) / 16

    def train(self,
              env,
              num_episodes=1000,
              batch_size=50,
              discount_factor=0.9,
              checkpoint_freq=None,
              exploration_range=(1.0, 0.1),
              exploration_phase_size=0.5):
        """
        Train the agent to perform well in the given Snake environment.

        Args:
            env:
                an instance of Snake environment.
            num_episodes (int):
                the number of episodes to run during the training.
            batch_size (int):
                the size of the learning sample for experience replay.
            discount_factor (float):
                discount factor (gamma) for computing the value function.
            checkpoint_freq (int):
                the number of episodes after which a new model checkpoint will be created.
            exploration_range (tuple):
                a (max, min) range specifying how the exploration rate should decay over time.
            exploration_phase_size (float):
                the percentage of the training process at which
                the exploration rate should reach its minimum.
        """

        # Calculate the constant exploration decay speed for each episode.
        max_exploration_rate, min_exploration_rate = exploration_range
        exploration_decay = ((max_exploration_rate - min_exploration_rate) /
                             (num_episodes * exploration_phase_size))
        exploration_rate = max_exploration_rate

        for episode in range(num_episodes):
            # Reset the environment for the new episode.
            timestep = env.new_episode()
            self.begin_episode()
            game_over = False
            loss_1 = 0.0
            loss_2 = 0.0
            alive_1 = True
            alive_2 = True
            # Observe the initial state.
            state = self.get_last_frames(timestep.observation)

            while not game_over:
                if np.random.random() < exploration_rate:
                    # Explore: take a random action.
                    action = (np.random.randint(env.num_actions),
                              np.random.randint(env.num_actions))
                else:
                    # Exploit: take the best known action for this state.
                    q1 = self.model_1.predict(state)
                    q2 = self.model_2.predict(state)
                    q1 = q1.reshape((env.num_actions, env.num_actions))
                    q2 = q2.reshape((env.num_actions, env.num_actions))
                    if alive_1 and alive_2:
                        action = (np.argmax(np.min(q1, axis=1)),
                                  np.argmax(np.min(q2, axis=1)))
                    elif alive_1:
                        action = (np.argmax(np.min(q1, axis=1)),
                                  np.argmin(np.max(q1, axis=0)))
                    elif alive_2:
                        action = (np.argmin(np.max(q2, axis=0)),
                                  np.argmax(np.min(q2, axis=1)))

                # Act on the environment.
                env.choose_action(action)
                timestep = env.timestep()

                # Remember a new piece of experience.
                reward_1, reward_2 = timestep.reward_1, timestep.reward_2
                state_next = self.get_last_frames(timestep.observation)
                game_over = timestep.is_episode_end

                experience_item_1 = [
                    state, action[0], action[1], reward_1, state_next,
                    game_over
                ]
                experience_item_2 = [
                    state, action[1], action[0], reward_2, state_next,
                    game_over
                ]
                self.memory_1.multi_remember(*experience_item_1)
                self.memory_2.multi_remember(*experience_item_2)
                state = state_next

                # Sample a random batch from experience.

                if alive_1:
                    batch = self.memory_1.get_multi_batch(
                        model=self.model_1,
                        batch_size=batch_size,
                        discount_factor=discount_factor)
                    # Learn on the batch.
                    if batch:
                        inputs, targets = batch
                        loss_1 += float(
                            self.model_1.train_on_batch(inputs, targets))

                # Sample a random batch from experience.
                if alive_2:
                    batch = self.memory_2.get_multi_batch(
                        model=self.model_2,
                        batch_size=batch_size,
                        discount_factor=discount_factor)
                    # Learn on the batch.
                    if batch:
                        inputs, targets = batch
                        loss_2 += float(
                            self.model_2.train_on_batch(inputs, targets))

                alive_1 = timestep.alive_1
                alive_2 = timestep.alive_2

            if checkpoint_freq and (episode % checkpoint_freq) == 0:
                self.model_1.save(f'dqn-mm1-{episode:08d}.model')
                self.model_2.save(f'dqn-mm2-{episode:08d}.model')

            if exploration_rate > min_exploration_rate:
                exploration_rate -= exploration_decay

            summary = 'Episode {:5d}/{:5d} | Loss {:8.4f}, {:8.4f} | Exploration {:.2f} | ' + \
                      'Fruits {:2d}, {:2d} | Timesteps {:4d} | Total Reward {:4d}, {:4d}'
            print(
                summary.format(episode + 1, num_episodes, loss_1, loss_2,
                               exploration_rate, env.stats.fruits_eaten_1,
                               env.stats.fruits_eaten_2,
                               env.stats.timesteps_survived,
                               env.stats.sum_episode_rewards_1,
                               env.stats.sum_episode_rewards_2))

        self.model_1.save('dqn-mm1-final.model')
        self.model_2.save('dqn-mm2-final.model')

    def act(self, observation, reward, alive_1=True, alive_2=True):
        """
        Choose the next action to take.

        Args:
            observation: observable state for the current timestep.
            reward: reward received at the beginning of the current timestep.
            alive_1 (bool): whether snake 1 is still alive.
            alive_2 (bool): whether snake 2 is still alive.

        Returns:
            A pair of action indices (snake 1's action, snake 2's action).
        """
        state = self.get_last_frames(observation)
        q1 = self.model_1.predict(state).reshape(3, 3)
        q2 = self.model_2.predict(state).reshape(3, 3)

        if alive_1 and alive_2:
            return (np.argmax(np.min(q1, axis=1)),
                    np.argmax(np.min(q2, axis=1)))
        elif alive_1:
            return (np.argmax(np.min(q1, axis=1)),
                    np.argmin(np.max(q1, axis=0)))
        elif alive_2:
            return (np.argmin(np.max(q2, axis=0)),
                    np.argmax(np.min(q2, axis=1)))
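
A worked sketch of the maximin rule used in train() and act() above: rows index the acting snake's own action, columns index the opponent's reply, and the Q values are made up for illustration.

import numpy as np

q1 = np.array([[ 1.0, -2.0,  0.5],
               [ 0.8,  0.3,  0.2],
               [-1.0,  2.0,  0.0]])
# Row minima are [-2.0, 0.2, -1.0], so the best worst-case own action is row 1.
own_action = np.argmax(np.min(q1, axis=1))   # -> 1
# Column maxima are [1.0, 2.0, 0.5], so the adversarial opponent reply is column 2.
opp_action = np.argmin(np.max(q1, axis=0))   # -> 2
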
Example #6
File: dqn.py  Project: exe1023/VIN-snake
class DeepQNetworkAgent(AgentBase):
    """ Represents a Snake agent powered by DQN with experience replay. """

    def __init__(self, model, num_last_frames=4, memory_size=1000, attention=1):
        """
        Create a new DQN-based agent.
        
        Args:
            model: a compiled DQN model.
            num_last_frames (int): the number of last frames the agent will consider.
            memory_size (int): memory size limit for experience replay (-1 for unlimited).
            attention (int): attention mode controlling which inputs are fed to the model (-1 for a plain single-input model).
        """
        #assert model.input_shape[0][1] == num_last_frames, 'Model input shape should be (num_frames, grid_size, grid_size)'
        #assert len(model.output_shape) == 2, 'Model output shape should be (num_samples, num_actions)'

        self.model = model
        self.num_last_frames = num_last_frames
        if attention != -1:
            self.memory = ExperienceReplay((num_last_frames,) + model.input_shape[0][-2:], model.output_shape[-1], memory_size, attention)
        else:
            self.memory = ExperienceReplay((num_last_frames,) + model.input_shape[-2:], model.output_shape[-1], memory_size, attention)
        self.frames = None
        self.attention = attention

    def begin_episode(self):
        """ Reset the agent for a new episode. """
        self.frames = None

    def get_last_frames(self, observation):
        """
        Get the pixels of the last `num_last_frames` observations, the current frame being the last.
        
        Args:
            observation: observation at the current timestep. 

        Returns:
            Observations for the last `num_last_frames` frames.
        """
        frame = observation
        if self.frames is None:
            self.frames = collections.deque([frame] * self.num_last_frames)
        else:
            self.frames.append(frame)
            self.frames.popleft()
        return np.expand_dims(self.frames, 0)

    def train(self, env, num_episodes=1000, batch_size=50, discount_factor=0.9, checkpoint_freq=None,
              exploration_range=(1.0, 0.1), exploration_phase_size=0.5):
        """
        Train the agent to perform well in the given Snake environment.
        
        Args:
            env:
                an instance of Snake environment.
            num_episodes (int):
                the number of episodes to run during the training.
            batch_size (int):
                the size of the learning sample for experience replay.
            discount_factor (float):
                discount factor (gamma) for computing the value function.
            checkpoint_freq (int):
                the number of episodes after which a new model checkpoint will be created.
            exploration_range (tuple):
                a (max, min) range specifying how the exploration rate should decay over time. 
            exploration_phase_size (float):
                the percentage of the training process at which
                the exploration rate should reach its minimum.
        """

        # Calculate the constant exploration decay speed for each episode.
        max_exploration_rate, min_exploration_rate = exploration_range
        exploration_decay = ((max_exploration_rate - min_exploration_rate) / (num_episodes * exploration_phase_size))
        exploration_rate = max_exploration_rate

        for episode in range(num_episodes):
            # Reset the environment for the new episode.
            timestep, position = env.new_episode()
            self.begin_episode()
            game_over = False
            loss = 0.0

            # Observe the initial state.
            state = self.get_last_frames(timestep.observation)

            while not game_over:
                if np.random.random() < exploration_rate:
                    # Explore: take a random action.
                    action = np.random.randint(env.num_actions)
                else:
                    # Exploit: take the best known action for this state.
                    s = np.array([(position[0], position[1])])
                    if self.attention == 0:
                        q = self.model.predict([state, state, s])
                    elif self.attention > 0:
                        q = self.model.predict([state, state])
                    else:
                        q = self.model.predict(state)
                    action = np.argmax(q[0])

                # Act on the environment.
                env.choose_action(action)
                timestep, position_next = env.timestep()

                # Remember a new piece of experience.
                reward = timestep.reward
                state_next = self.get_last_frames(timestep.observation)
                game_over = timestep.is_episode_end
                experience_item = [state, position, action, reward, state_next, position_next, game_over]
                self.memory.remember(*experience_item)
                state = state_next
                position = position_next

                # Sample a random batch from experience.
                batch = self.memory.get_batch(
                    model=self.model,
                    batch_size=batch_size,
                    discount_factor=discount_factor
                )
                # Learn on the batch.
                if batch:
                    inputs, s, targets = batch
                    #print(episode)
                    #print(inputs)
                    #print(targets)
                    if self.attention == 0:
                        loss += float(self.model.train_on_batch([inputs, inputs, s], targets))
                    elif self.attention > 0:
                        loss += float(self.model.train_on_batch([inputs, inputs], targets))
                    else:
                        loss += float(self.model.train_on_batch(inputs, targets))


            if checkpoint_freq and (episode % checkpoint_freq) == 0:
                #self.model.save(f'dqn-{episode:08d}.model')
                self.model.save('dqn-' + str(episode) + '.model')

            if exploration_rate > min_exploration_rate:
                exploration_rate -= exploration_decay

            summary = 'Episode {:5d}/{:5d} | Loss {:8.4f} | Exploration {:.2f} | ' + \
                      'Fruits {:2d} | Timesteps {:4d} | Total Reward {:4d}'
            print(summary.format(
                episode + 1, num_episodes, loss, exploration_rate,
                env.stats.fruits_eaten, env.stats.timesteps_survived, env.stats.sum_episode_rewards,
            ))

        self.model.save('dqn-final.model')

    def act(self, observation, position, reward, attention=1):
        """
        Choose the next action to take.
        
        Args:
            observation: observable state for the current timestep.
            position: the position returned by the environment for the current timestep (used when attention == 0).
            reward: reward received at the beginning of the current timestep.
            attention (int): attention mode (same semantics as in the constructor).

        Returns:
            The index of the action to take next.
        """
        state = self.get_last_frames(observation)
        s = np.array([(position[0], position[1])])
        if attention > 0:
            q = self.model.predict([state, state])[0]
        elif attention == 0:
            q = self.model.predict([state, state, s])[0]
        else:
            q = self.model.predict(state)[0]
        return np.argmax(q)

    def visualize(self, observation, position, reward, attention=1, visualize=None):
        """
        Choose the next action to take.
        
        Args:
            observation: observable state for the current timestep. 
            reward: reward received at the beginning of the current timestep.

        Returns:
            The index of the action to take next.
        """
        state = self.get_last_frames(observation)
        s = np.array([(position[0], position[1])])
        if attention > 0:
            q = self.model.predict([state, state])[0]
            return visualize([state])
        else:
            q = self.model.predict([state, state, s])[0]
            return visualize([state, s]), np.argmax(q)
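
A hedged usage sketch for the agent above: the attention flag decides which input list the compiled model receives, so the model architecture must match the chosen mode. build_vin_model and env are hypothetical stand-ins for the project's model builder and Snake environment.

# attention == 0:  model receives [state, state, position]
# attention  > 0:  model receives [state, state]
# attention == -1: model receives state alone
agent = DeepQNetworkAgent(build_vin_model(), num_last_frames=4,
                          memory_size=1000, attention=0)
agent.train(env, num_episodes=20000, batch_size=64, discount_factor=0.9)
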
Example #7
class DeepQNetworkAgent(AgentBase):
    """ Represents a Snake agent powered by DQN with experience replay. """
    def __init__(self,
                 model,
                 env_shape,
                 num_actions,
                 num_last_frames=4,
                 memory_size=1000):
        """
        Create a new DQN-based agent.

        Args:
            model: a DQN model.
            env_shape (int, int): shape of the environment.
            num_actions (int): number of actions.
            num_last_frames (int): the number of last frames the agent will consider.
            memory_size (int): memory size limit for experience replay (-1 for unlimited).
        """
        self.model = model
        self.loss_fn = nn.MSELoss()
        self.optimizer = torch.optim.RMSprop(self.model.parameters(), lr=0.001)

        self.num_last_frames = num_last_frames
        self.memory = ExperienceReplay((num_last_frames, ) + env_shape,
                                       num_actions, memory_size)
        self.frames = None

    def begin_episode(self):
        """ Reset the agent for a new episode. """
        self.frames = None

    def get_last_frames(self, observation):
        """
        Get the pixels of the last `num_last_frames` observations, the current frame being the last.

        Args:
            observation: observation at the current timestep.

        Returns:
            Observations for the last `num_last_frames` frames.
        """
        frame = observation
        if self.frames is None:
            self.frames = collections.deque([frame] * self.num_last_frames)
        else:
            self.frames.append(frame)
            self.frames.popleft()
        return np.expand_dims(self.frames, 0)

    def train(
            self,
            env,
            num_episodes=1000,
            batch_size=50,
            discount_factor=0.9,
            checkpoint_freq=None,
            exploration_range=(1.0, 0.1),
            exploration_phase_size=0.5,
    ):
        """
        Train the agent to perform well in the given Snake environment.

        Args:
            env:
                an instance of Snake environment.
            num_episodes (int):
                the number of episodes to run during the training.
            batch_size (int):
                the size of the learning sample for experience replay.
            discount_factor (float):
                discount factor (gamma) for computing the value function.
            checkpoint_freq (int):
                the number of episodes after which a new model checkpoint will be created.
            exploration_range (tuple):
                a (max, min) range specifying how the exploration rate should decay over time.
            exploration_phase_size (float):
                the percentage of the training process at which
                the exploration rate should reach its minimum.
        """

        # Calculate the constant exploration decay speed for each episode.
        max_exploration_rate, min_exploration_rate = exploration_range
        exploration_decay = (max_exploration_rate - min_exploration_rate) / (
            num_episodes * exploration_phase_size)
        exploration_rate = max_exploration_rate

        for episode in range(num_episodes):
            # Reset the environment for the new episode.
            timestep = env.new_episode()
            self.begin_episode()
            game_over = False
            loss = 0.0

            # Observe the initial state.
            state = self.get_last_frames(timestep.observation)
            while not game_over:
                if np.random.random() < exploration_rate:
                    # Explore: take a random action.
                    action = np.random.randint(env.num_actions)
                else:
                    # Exploit: take the best known action for this state.
                    q = self.model(torch.Tensor(state))
                    action = np.argmax(q[0].detach()).item()

                # Act on the environment.
                env.choose_action(action)
                timestep = env.timestep()

                # Remember a new piece of experience.
                reward = timestep.reward
                state_next = self.get_last_frames(timestep.observation)
                game_over = timestep.is_episode_end
                experience_item = [
                    state, action, reward, state_next, game_over
                ]
                self.memory.remember(*experience_item)
                state = state_next

                # Sample a random batch from experience.
                batch = self.memory.get_batch(
                    model=self.model,
                    batch_size=batch_size,
                    discount_factor=discount_factor,
                )
                # Learn on the batch.
                if batch:
                    inputs, targets = batch
                    self.optimizer.zero_grad()
                    predictions = self.model(torch.Tensor(inputs))
                    batch_loss = self.loss_fn(predictions,
                                              torch.Tensor(targets))
                    loss += batch_loss.item()  # accumulate as a float so the graph is not retained
                    # Backpropagation
                    batch_loss.backward()
                    self.optimizer.step()

            if checkpoint_freq and (episode % checkpoint_freq) == 0:
                torch.save(self.model, f"dqn-{episode:08d}.model")

            if exploration_rate > min_exploration_rate:
                exploration_rate -= exploration_decay

            summary = (
                "Episode {:5d}/{:5d} | Loss {:8.4f} | Exploration {:.2f} | " +
                "Fruits {:2d} | Timesteps {:4d} | Total Reward {:4d}")
            print(
                summary.format(
                    episode + 1,
                    num_episodes,
                    loss,
                    exploration_rate,
                    env.stats.fruits_eaten,
                    env.stats.timesteps_survived,
                    env.stats.sum_episode_rewards,
                ))

        torch.save(self.model, "dqn-final.model")

    def act(self, observation, reward):
        """
        Choose the next action to take.

        Args:
            observation: observable state for the current timestep.
            reward: reward received at the beginning of the current timestep.

        Returns:
            The index of the action to take next.
        """
        state = self.get_last_frames(observation)
        with torch.no_grad():
            q = self.model(torch.Tensor(state))
        action = np.argmax(q[0]).item()
        return action
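
A hedged usage sketch for the PyTorch agent above; SimpleDQN is the sketch model from Example #4 and SnakeEnv is a hypothetical environment exposing the interface train() relies on (new_episode, choose_action, timestep, num_actions, stats).

env = SnakeEnv(grid_size=10)
model = SimpleDQN(num_last_frames=4, env_shape=(10, 10),
                  num_actions=env.num_actions)
agent = DeepQNetworkAgent(model, env_shape=(10, 10),
                          num_actions=env.num_actions)
agent.train(env, num_episodes=10000, batch_size=64, checkpoint_freq=500)
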
Example #8
class DeepQNetworkAgent(AgentBase):
    """ Represents a Snake agent powered by DQN with experience replay. """
    def __init__(self, model, num_last_frames=4, memory_size=1000, output="."):
        """
        Create a new DQN-based agent.
        
        Args:
            model: a sequence of compiled DQN models (the agent indexes model[0]; a second model is used when method='ddqn').
            num_last_frames (int): the number of last frames the agent will consider.
            memory_size (int): memory size limit for experience replay (-1 for unlimited). 
            output (str): folder path to output model files.
        """
        assert model[0].input_shape[1] == num_last_frames, \
            'Model input shape should be (num_frames, grid_size, grid_size)'
        assert len(model[0].output_shape) == 2, \
            'Model output shape should be (num_samples, num_actions)'

        self.model = model
        self.num_last_frames = num_last_frames
        self.memory = ExperienceReplay(
            (num_last_frames, ) + model[0].input_shape[-2:],
            model[0].output_shape[-1], memory_size)
        self.frames = None
        self.output = output
        self.num_frames = 0
        self.num_trained_frames = 0

    def begin_episode(self):
        """ Reset the agent for a new episode. """
        self.frames = None

    def get_last_frames(self, observation):
        """
        Get the pixels of the last `num_last_frames` observations, the current frame being the last.
        
        Args:
            observation: observation at the current timestep. 

        Returns:
            Observations for the last `num_last_frames` frames.
        """
        frame = observation
        if self.frames is None:
            self.frames = collections.deque([frame] * self.num_last_frames)
        else:
            self.frames.append(frame)
            self.frames.popleft()
        return np.expand_dims(self.frames, 0)

    def train(self,
              env,
              num_episodes=1000,
              batch_size=50,
              discount_factor=0.9,
              checkpoint_freq=None,
              method='dqn',
              multi_step='False'):
        """
        Train the agent to perform well in the given Snake environment.
        
        Args:
            env:
                an instance of Snake environment.
            num_episodes (int):
                the number of episodes to run during the training.
            batch_size (int):
                the size of the learning sample for experience replay.
            discount_factor (float):
                discount factor (gamma) for computing the value function.
            checkpoint_freq (int):
                the number of episodes after which a new model checkpoint will be created.
            method (str):
                learning method, either 'dqn' or 'ddqn' (double DQN).
            multi_step:
                multi-step flag passed through to the experience replay sampler.
        """
        timestamp = time.strftime('%Y%m%d-%H%M%S')

        episode = 0
        while episode != num_episodes:
            episode += 1
            exploration_rate = 1 - 0.00009 * episode if episode < 10000 else (
                10 / np.sqrt(episode))

            # Reset the environment for the new episode.
            timestep = env.new_episode()
            self.begin_episode()
            game_over = False
            loss = 0.0
            model_to_udate = np.random.randint(0, 2) if method == 'ddqn' else 0

            # Observe the initial state.
            state = self.get_last_frames(timestep.observation)

            while not game_over:
                if np.random.random() < exploration_rate:
                    # Explore: take a random action.
                    action = np.random.randint(env.num_actions)
                else:
                    # Exploit: take the best known action for this state.
                    q = self.model[model_to_udate].predict(state)
                    action = np.argmax(q[0])

                # Act on the environment.
                env.choose_action(action)
                timestep = env.timestep()

                # Remember a new piece of experience.
                reward = timestep.reward
                state_next = self.get_last_frames(timestep.observation)

                if np.random.random() < exploration_rate:
                    # Explore: take a random action.
                    action_next = np.random.randint(env.num_actions)
                else:
                    # Exploit: take the best known action for this state.
                    q = self.model[model_to_udate].predict(state_next)
                    action_next = np.argmax(q[0])

                game_over = timestep.is_episode_end
                experience_item = [
                    state, action, reward, state_next, action_next, game_over
                ]
                self.memory.remember(*experience_item)
                state = state_next

                # Sample a random batch from experience.
                batch = self.memory.get_batch(
                    model=self.model,
                    batch_size=batch_size,
                    exploration_rate=exploration_rate,
                    discount_factor=discount_factor,
                    method=method,
                    model_to_udate=model_to_udate,
                    multi_step=multi_step)

                # Learn on the batch.
                if batch:
                    inputs, targets = batch
                    self.num_trained_frames += targets.size
                    loss += float(self.model[model_to_udate].train_on_batch(
                        inputs, targets))

                if Config.PRIORITIZED_REPLAY:
                    # Sample a random batch from experience.
                    batch = self.memory.get_batch(
                        model=self.model,
                        batch_size=batch_size,
                        exploration_rate=exploration_rate,
                        discount_factor=discount_factor,
                        method=method,
                        model_to_udate=model_to_udate,
                        multi_step=multi_step,
                        get_latest_replay=True)

                    # Learn on the batch.
                    if batch:
                        inputs, targets = batch
                        self.num_trained_frames += targets.size
                        replay_loss = float(
                            self.model[model_to_udate].train_on_batch(
                                inputs, targets))
                        input_loss = np.minimum(10, int(replay_loss))
                        self.memory.remember_prioritized_ratio(
                            np.ceil(
                                np.power(input_loss + 1,
                                         Config.PRIORITIZED_RATING)))

                        with open(f'{self.output}/training-loss.txt',
                                  'a') as f:
                            with redirect_stdout(f):
                                print(episode, self.num_frames, replay_loss)

            if checkpoint_freq and (episode % checkpoint_freq) == 0:
                self.model[0].save(f'{self.output}/dqn-{episode:08d}.model')
                self.evaluate(env,
                              trained_episode=episode,
                              num_test_episode=15)

            self.num_frames += env.stats.timesteps_survived

            summary = 'Episode {:5d}/{:5d} | Loss {:8.4f} | Exploration {:.3f} | ' + \
                      'Fruits {:2d} | Timesteps {:4d} | Reward {:4d} | ' + \
                      'Memory {:6d} | Total Timesteps {:6d} | Trained Frames {:11d}'

            print(
                summary.format(episode, num_episodes, loss,
                               exploration_rate, env.stats.fruits_eaten,
                               env.stats.timesteps_survived,
                               env.stats.sum_episode_rewards,
                               len(self.memory.memory), self.num_frames,
                               self.num_trained_frames))
            with open(f'{self.output}/training-log.txt', 'a') as f:
                with redirect_stdout(f):
                    print(
                        summary.format(episode, num_episodes, loss,
                                       exploration_rate,
                                       env.stats.fruits_eaten,
                                       env.stats.timesteps_survived,
                                       env.stats.sum_episode_rewards,
                                       len(self.memory.memory),
                                       self.num_frames,
                                       self.num_trained_frames))

        self.model[0].save(f'{self.output}/dqn-final.model')
        self.evaluate(env, trained_episode=episode, num_test_episode=15)
        print('Training End - saved to ' + str(self.output))

    def act(self, observation, reward):
        """
        Choose the next action to take.
        
        Args:
            observation: observable state for the current timestep. 
            reward: reward received at the beginning of the current timestep.

        Returns:
            The index of the action to take next.
        """
        state = self.get_last_frames(observation)
        q = self.model[0].predict(state)[0]
        return np.argmax(q)

    def evaluate(self, env, trained_episode, num_test_episode):
        """
        Play a set of episodes using the specified Snake agent.
        Use the non-interactive command-line interface and print the summary statistics afterwards.
        
        Args:
            env: an instance of Snake environment.
            trained_episode (int): trained episodes.
            num_test_episode (int): the number of episodes to run.
        """

        fruit_stats = []
        timestep_stats = []
        reward_stats = []

        print()
        print('Playing:')

        for episode in range(num_test_episode):
            timestep = env.new_episode()
            self.begin_episode()
            game_over = False

            while not game_over:
                action = self.act(timestep.observation, timestep.reward)
                env.choose_action(action)
                timestep = env.timestep()
                game_over = timestep.is_episode_end

            fruit_stats.append(env.stats.fruits_eaten)
            timestep_stats.append(env.stats.timesteps_survived)
            reward_stats.append(env.stats.sum_episode_rewards)

            summary = 'Episode {:3d} / {:3d} | Timesteps {:4d} | Fruits {:2d} | Reward {:3d}'
            print(summary.format(episode + 1, num_test_episode,
                                 env.stats.timesteps_survived,
                                 env.stats.fruits_eaten,
                                 env.stats.sum_episode_rewards))

        print('Fruits eaten {:.1f} +/- stddev {:.1f}'.format(
            np.mean(fruit_stats), np.std(fruit_stats)))
        print('Reward {:.1f} +/- stddev {:.1f}'.format(np.mean(reward_stats),
                                                       np.std(reward_stats)))
        print()

        with open(f'{self.output}/training-stat.txt', 'a') as f:
            with redirect_stdout(f):
                summary = 'Episode {:7d} | Average Timesteps {:4.0f} | Average Fruits {:.1f} | Average Reward {:.1f}'
                print(
                    summary.format(trained_episode, np.mean(timestep_stats),
                                   np.mean(fruit_stats),
                                   np.mean(reward_stats)))
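
For reference, a sketch of the exploration schedule hard-coded in Example #8's train(): a linear decay from roughly 1.0 down to 0.1 over the first 10,000 episodes, followed by a slower 10 / sqrt(episode) tail.

import numpy as np

def exploration_rate(episode):
    return 1 - 0.00009 * episode if episode < 10000 else 10 / np.sqrt(episode)

print(exploration_rate(1))        # ~0.99991
print(exploration_rate(10000))    # 0.1
print(exploration_rate(40000))    # 0.05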