Example #1
0
File: dqn.py  Project: keygenx/Snakes
class DeepQNetworkAgent(AgentBase):
    """ Represents a Snake agent powered by DQN with experience replay. """

    def __init__(self, model, num_last_frames=4, memory_size=1000):
        """
        Create a new DQN-based agent.

        Args:
            model: a compiled DQN model.
            num_last_frames (int): the number of last frames the agent will consider.
            memory_size (int): memory size limit for experience replay (-1 for unlimited).
        """
        assert model.input_shape[
            1] == num_last_frames, 'Model input shape should be (num_frames, grid_size, grid_size)'
        assert len(
            model.output_shape
        ) == 2, 'Model output shape should be (num_samples, num_actions)'

        self.model = model
        self.num_last_frames = num_last_frames
        # Replay states are stacks of `num_last_frames` frames shaped like the trailing
        # two dimensions of the model input; the last output dimension is the number
        # of actions (per the output-shape assertion above).
        self.memory = ExperienceReplay(
            (num_last_frames, ) + model.input_shape[-2:],
            model.output_shape[-1], memory_size)
        # Rolling frame history of the current episode; None until the first observation.
        self.frames = None

    def begin_episode(self):
        """ Reset the agent for a new episode. """
        self.frames = None

    def get_last_frames(self, observation):
        """
        Get the pixels of the last `num_last_frames` observations, the current frame being the last.

        On the first call of an episode the history is padded by repeating the
        current observation `num_last_frames` times.

        Args:
            observation: observation at the current timestep.

        Returns:
            Observations for the last `num_last_frames` frames, with a leading
            batch axis of size 1 added by `np.expand_dims`.
        """
        if self.frames is None:
            self.frames = collections.deque([observation] * self.num_last_frames,
                                            maxlen=self.num_last_frames)
        else:
            # The bounded deque discards the oldest frame automatically.
            self.frames.append(observation)
        return np.expand_dims(self.frames, 0)

    def train(self,
              env,
              num_episodes=1000,
              batch_size=50,
              discount_factor=0.9,
              checkpoint_freq=None,
              exploration_range=(1.0, 0.1),
              exploration_phase_size=0.5):
        """
        Train the agent to perform well in the given Snake environment.

        Args:
            env:
                an instance of Snake environment.
            num_episodes (int):
                the number of episodes to run during the training.
            batch_size (int):
                the size of the learning sample for experience replay.
            discount_factor (float):
                discount factor (gamma) for computing the value function.
            checkpoint_freq (int):
                the number of episodes after which a new model checkpoint will be created.
            exploration_range (tuple):
                a (max, min) range specifying how the exploration rate should decay over time.
            exploration_phase_size (float):
                the percentage of the training process at which
                the exploration rate should reach its minimum.
        """
        # Constant (linear) exploration decay applied once per episode.
        max_exploration_rate, min_exploration_rate = exploration_range
        exploration_episodes = num_episodes * exploration_phase_size
        if exploration_episodes > 0:
            exploration_decay = ((max_exploration_rate - min_exploration_rate) /
                                 exploration_episodes)
        else:
            # Empty exploration phase: reach the minimum after the first episode
            # instead of dividing by zero.
            exploration_decay = max_exploration_rate - min_exploration_rate
        exploration_rate = max_exploration_rate

        for episode in range(num_episodes):
            # Reset the environment for the new episode.
            timestep = env.new_episode()
            self.begin_episode()
            game_over = False
            loss = 0.0

            # Observe the initial state.
            state = self.get_last_frames(timestep.observation)

            while not game_over:
                if np.random.random() < exploration_rate:
                    # Explore: take a random action.
                    action = np.random.randint(env.num_actions)
                else:
                    # Exploit: take the best known action for this state.
                    q = self.model.predict(state)
                    action = np.argmax(q[0])

                # Act on the environment.
                env.choose_action(action)
                timestep = env.timestep()

                # Remember a new piece of experience.
                reward = timestep.reward
                state_next = self.get_last_frames(timestep.observation)
                game_over = timestep.is_episode_end
                self.memory.remember(state, action, reward, state_next, game_over)
                state = state_next

                # Sample a random batch from experience.
                batch = self.memory.get_batch(model=self.model,
                                              batch_size=batch_size,
                                              discount_factor=discount_factor)

                # Learn on the batch; `loss` sums the per-step losses of this episode.
                if batch:
                    inputs, targets = batch
                    loss += float(self.model.train_on_batch(inputs, targets))

            # `episode` is zero-based, so the first checkpoint is written after the
            # very first episode.
            if checkpoint_freq and (episode % checkpoint_freq) == 0:
                self.model.save(f'dqn-{episode:08d}.model')

            if exploration_rate > min_exploration_rate:
                # Decay linearly, but never undershoot the configured minimum.
                exploration_rate = max(min_exploration_rate,
                                       exploration_rate - exploration_decay)

            summary = 'Episode {:5d}/{:5d} | Loss {:8.4f} | Exploration {:.2f} | ' + \
                      'Fruits {:2d} | Timesteps {:4d} | Total Reward {:4d}'
            print(
                summary.format(
                    episode + 1,
                    num_episodes,
                    loss,
                    exploration_rate,
                    env.stats.fruits_eaten,
                    env.stats.timesteps_survived,
                    env.stats.sum_episode_rewards,
                ))

        self.model.save('dqn-final.model')

    def act(self, observation, reward):
        """
        Choose the next action to take (greedy with respect to the model's Q-values).

        Args:
            observation: observable state for the current timestep.
            reward: reward received at the beginning of the current timestep (unused).

        Returns:
            The index of the action to take next.
        """
        state = self.get_last_frames(observation)
        q = self.model.predict(state)[0]
        return np.argmax(q)
Example #2
0
class DeepQNetworkAgent(AgentBase):
    """ Represents a Snake agent powered by DQN with experience replay.

    `model` is a sequence of compiled models. `model[0]` is used for acting,
    checkpointing and evaluation; with ``method='ddqn'`` each training episode
    updates a randomly chosen one of ``model[0]`` / ``model[1]``.
    """

    def __init__(self, model, num_last_frames=4, memory_size=1000, output="."):
        """
        Create a new DQN-based agent.

        Args:
            model: a sequence of compiled DQN models (index 0 is the primary one).
            num_last_frames (int): the number of last frames the agent will consider.
            memory_size (int): memory size limit for experience replay (-1 for unlimited).
            output (str): folder path to output model files.
        """
        assert model[0].input_shape[
            1] == num_last_frames, 'Model input shape should be (num_frames, grid_size, grid_size)'
        assert len(
            model[0].output_shape
        ) == 2, 'Model output shape should be (num_samples, num_actions)'

        self.model = model
        self.num_last_frames = num_last_frames
        self.memory = ExperienceReplay(
            (num_last_frames, ) + model[0].input_shape[-2:],
            model[0].output_shape[-1], memory_size)
        # Rolling frame history of the current episode; None until the first observation.
        self.frames = None
        self.output = output
        # Environment timesteps accumulated over all finished training episodes.
        self.num_frames = 0
        # Running sum of `targets.size` over every batch the models were fitted on.
        self.num_trained_frames = 0

    def begin_episode(self):
        """ Reset the agent for a new episode. """
        self.frames = None

    def get_last_frames(self, observation):
        """
        Get the pixels of the last `num_last_frames` observations, the current frame being the last.

        On the first call of an episode the history is padded by repeating the
        current observation `num_last_frames` times.

        Args:
            observation: observation at the current timestep.

        Returns:
            Observations for the last `num_last_frames` frames, with a leading
            batch axis of size 1 added by `np.expand_dims`.
        """
        if self.frames is None:
            self.frames = collections.deque([observation] * self.num_last_frames,
                                            maxlen=self.num_last_frames)
        else:
            # The bounded deque discards the oldest frame automatically.
            self.frames.append(observation)
        return np.expand_dims(self.frames, 0)

    @staticmethod
    def _exploration_rate(episode):
        """Return the epsilon-greedy exploration rate for a 1-based episode index.

        Decays linearly (1 - 0.00009 * episode) below episode 10000, then
        follows 10 / sqrt(episode); both give ~0.1 around episode 10000.
        """
        if episode < 10000:
            return 1 - 0.00009 * episode
        return 10 / np.sqrt(episode)

    def _epsilon_greedy(self, state, exploration_rate, model_index, num_actions):
        """Return a random action with probability `exploration_rate`, else model[model_index]'s argmax action."""
        if np.random.random() < exploration_rate:
            # Explore: take a random action.
            return np.random.randint(num_actions)
        # Exploit: take the best known action for this state.
        q = self.model[model_index].predict(state)
        return np.argmax(q[0])

    def _learn_from_batch(self, batch, model_index):
        """Fit model[model_index] on an (inputs, targets) batch; return the loss, or None for an empty batch."""
        if not batch:
            return None
        inputs, targets = batch
        self.num_trained_frames += targets.size
        return float(self.model[model_index].train_on_batch(inputs, targets))

    def train(self,
              env,
              num_episodes=1000,
              batch_size=50,
              discount_factor=0.9,
              checkpoint_freq=None,
              method='dqn',
              multi_step='False'):
        """
        Train the agent to perform well in the given Snake environment.

        Per-episode summaries are printed and appended to
        `<output>/training-log.txt`; checkpoints and the final model are saved
        under `output` and evaluated with `evaluate`.

        Args:
            env:
                an instance of Snake environment.
            num_episodes (int):
                the number of episodes to run during the training. A negative
                value never matches the episode counter, so training then runs
                until interrupted.
            batch_size (int):
                the size of the learning sample for experience replay.
            discount_factor (float):
                discount factor (gamma) for computing the value function.
            checkpoint_freq (int):
                the number of episodes after which a new model checkpoint will be created.
            method (str):
                forwarded to the replay memory's `get_batch`; 'ddqn' also makes
                each episode train a randomly chosen one of the two models.
            multi_step:
                forwarded unchanged to the replay memory's `get_batch`.
                NOTE(review): the default is the *string* 'False', which is
                truthy — confirm how `get_batch` interprets it.
        """
        episode = 0
        while episode != num_episodes:
            episode += 1  # 1-based episode index from here on
            exploration_rate = self._exploration_rate(episode)

            # Reset the environment for the new episode.
            timestep = env.new_episode()
            self.begin_episode()
            game_over = False
            loss = 0.0
            # With DDQN, each episode acts with and trains one of two models at random.
            model_to_update = np.random.randint(0, 2) if method == 'ddqn' else 0

            # Observe the initial state.
            state = self.get_last_frames(timestep.observation)

            while not game_over:
                action = self._epsilon_greedy(state, exploration_rate,
                                              model_to_update, env.num_actions)

                # Act on the environment.
                env.choose_action(action)
                timestep = env.timestep()

                # Remember a new piece of experience, including the follow-up
                # action chosen by the same epsilon-greedy policy.
                reward = timestep.reward
                state_next = self.get_last_frames(timestep.observation)
                action_next = self._epsilon_greedy(state_next, exploration_rate,
                                                   model_to_update, env.num_actions)
                game_over = timestep.is_episode_end
                self.memory.remember(state, action, reward, state_next,
                                     action_next, game_over)
                state = state_next

                # Sample a random batch from experience and learn on it.
                # NB: `model_to_udate` (sic) is the keyword name get_batch expects.
                batch = self.memory.get_batch(
                    model=self.model,
                    batch_size=batch_size,
                    exploration_rate=exploration_rate,
                    discount_factor=discount_factor,
                    method=method,
                    model_to_udate=model_to_update,
                    multi_step=multi_step)
                batch_loss = self._learn_from_batch(batch, model_to_update)
                if batch_loss is not None:
                    loss += batch_loss

                if Config.PRIORITIZED_REPLAY:
                    # Additionally train on the latest replay and feed the
                    # (clipped) loss back as a prioritization ratio.
                    batch = self.memory.get_batch(
                        model=self.model,
                        batch_size=batch_size,
                        exploration_rate=exploration_rate,
                        discount_factor=discount_factor,
                        method=method,
                        model_to_udate=model_to_update,
                        multi_step=multi_step,
                        get_latest_replay=True)
                    replay_loss = self._learn_from_batch(batch, model_to_update)
                    if replay_loss is not None:
                        input_loss = np.minimum(10, int(replay_loss))
                        self.memory.remember_prioritized_ratio(
                            np.ceil(
                                np.power(input_loss + 1,
                                         Config.PRIORITIZED_RATING)))
                        with open(f'{self.output}/training-loss.txt', 'a') as f:
                            print(episode, self.num_frames, replay_loss, file=f)

            # Snapshot this training episode's stats now: the checkpoint
            # evaluation below plays new episodes on the same env, after which
            # env.stats would no longer describe the training episode.
            fruits_eaten = env.stats.fruits_eaten
            timesteps_survived = env.stats.timesteps_survived
            episode_reward = env.stats.sum_episode_rewards

            if checkpoint_freq and (episode % checkpoint_freq) == 0:
                self.model[0].save(f'{self.output}/dqn-{episode:08d}.model')
                self.evaluate(env,
                              trained_episode=episode,
                              num_test_episode=15)

            self.num_frames += timesteps_survived

            summary = 'Episode {:5d}/{:5d} | Loss {:8.4f} | Exploration {:.3f} | ' + \
                      'Fruits {:2d} | Timesteps {:4d} | Reward {:4d} | ' + \
                      'Memory {:6d} | Total Timesteps {:6d} | Trained Frames{:11d}'
            summary_line = summary.format(episode, num_episodes, loss,
                                          exploration_rate, fruits_eaten,
                                          timesteps_survived, episode_reward,
                                          len(self.memory.memory),
                                          self.num_frames,
                                          self.num_trained_frames)
            print(summary_line)
            with open(f'{self.output}/training-log.txt', 'a') as f:
                print(summary_line, file=f)

        self.model[0].save(f'{self.output}/dqn-final.model')
        self.evaluate(env, trained_episode=episode, num_test_episode=15)
        print('Training End - saved to ' + str(self.output))

    def act(self, observation, reward):
        """
        Choose the next action to take (greedy with respect to model[0]'s Q-values).

        Args:
            observation: observable state for the current timestep.
            reward: reward received at the beginning of the current timestep (unused).

        Returns:
            The index of the action to take next.
        """
        state = self.get_last_frames(observation)
        q = self.model[0].predict(state)[0]
        return np.argmax(q)

    def evaluate(self, env, trained_episode, num_test_episode):
        """
        Play a set of episodes with the greedy policy of `act` and report statistics.

        Prints per-episode results plus mean/stddev of fruits eaten and reward,
        and appends a one-line average summary to `<output>/training-stat.txt`.

        Args:
            env: an instance of Snake environment.
            trained_episode (int): trained episodes.
            num_test_episode (int): the number of episodes to run.
        """
        fruit_stats = []
        timestep_stats = []
        reward_stats = []

        print()
        print('Playing:')

        for episode in range(num_test_episode):
            timestep = env.new_episode()
            self.begin_episode()
            game_over = False

            while not game_over:
                action = self.act(timestep.observation, timestep.reward)
                env.choose_action(action)
                timestep = env.timestep()
                game_over = timestep.is_episode_end

            fruit_stats.append(env.stats.fruits_eaten)
            timestep_stats.append(env.stats.timesteps_survived)
            reward_stats.append(env.stats.sum_episode_rewards)

            summary = 'Episode {:3d} / {:3d} | Timesteps {:4d} | Fruits {:2d} | Reward {:3d}'
            print(summary.format(episode + 1, num_test_episode,
                                 env.stats.timesteps_survived,
                                 env.stats.fruits_eaten,
                                 env.stats.sum_episode_rewards))

        print('Fruits eaten {:.1f} +/- stddev {:.1f}'.format(
            np.mean(fruit_stats), np.std(fruit_stats)))
        print('Reward {:.1f} +/- stddev {:.1f}'.format(np.mean(reward_stats),
                                                       np.std(reward_stats)))
        print()

        summary = 'Episode {:7d} | Average Timesteps {:4.0f} | Average Fruits {:.1f} | Average Reward {:.1f}'
        with open(f'{self.output}/training-stat.txt', 'a') as f:
            print(summary.format(trained_episode, np.mean(timestep_stats),
                                 np.mean(fruit_stats), np.mean(reward_stats)),
                  file=f)