コード例 #1
0
 def __init__(self):
     """ Initialise Pygame and put the GUI into its initial, unloaded state. """
     pygame.init()

     # The agent defaults to a HumanAgent instance.
     self.agent = HumanAgent()

     # No environment, screen or FPS clock yet: all three start as None.
     self.env = self.screen = self.fps_clock = None

     self.timestep_watch = Stopwatch()
コード例 #2
0
def create_agent(name, model, model2):
    """
    Build the Snake agent selected by ``name``.

    Args:
        name (str): agent type key; one of 'human', 'minimaxdqn',
            'minimaxdqnsingle' or 'random'.
        model: model handed to the minimax DQN agents (as ``model_1`` for
            'minimaxdqn', as ``model`` for 'minimaxdqnsingle'). Ignored by
            the other agent types.
        model2: second model, passed as ``model_2`` and used only by
            'minimaxdqn'.

    Returns:
        The newly constructed agent instance.

    Raises:
        KeyError: if ``name`` is not one of the supported keys.
    """

    from snakeai.agent import DeepQNetworkAgent, HumanAgent, RandomActionAgent

    # NOTE(review): the two Minimax agent classes used below are not part of
    # the import above; presumably they are imported at module level - confirm.
    if name == 'human':
        return HumanAgent()

    if name == 'random':
        return RandomActionAgent()

    if name == 'minimaxdqn':
        return MinimaxDeepQNetworkAgent(
            model_1=model,
            model_2=model2,
            memory_size=-1,
            num_last_frames=4,
        )

    if name == 'minimaxdqnsingle':
        return MinimaxSingleDeepQNetworkAgent(
            model=model,
            memory_size=-1,
            num_last_frames=4,
        )

    # None of the known keys matched.
    raise KeyError(f'Unknown agent type: "{name}"')
コード例 #3
0
def create_agent(name, model):
    """
    Build the Snake agent selected by ``name``.

    Args:
        name (str): agent type key; one of 'human', 'dqn' or 'random'.
        model: model passed to the DQN agent; must not be None when
            ``name`` is 'dqn'. Ignored by the other agent types.

    Returns:
        The newly constructed agent instance.

    Raises:
        ValueError: if a 'dqn' agent is requested and ``model`` is None.
        KeyError: if ``name`` is not one of the supported keys.
    """

    from snakeai.agent import DeepQNetworkAgent, HumanAgent, RandomActionAgent

    # Agents that need no model are handled first.
    if name == 'human':
        return HumanAgent()
    if name == 'random':
        return RandomActionAgent()

    # Anything that is not 'dqn' at this point is an unsupported key.
    if name != 'dqn':
        raise KeyError(f'Unknown agent type: "{name}"')

    if model is None:
        raise ValueError('A model file is required for a DQN agent.')

    return DeepQNetworkAgent(model=model, memory_size=-1, num_last_frames=4)
コード例 #4
0
def create_agent(name, model, env):
    """
    Build the Snake agent selected by ``name``.

    Args:
        name (str): agent type key; one of "human", "dqn" or "random".
        model: model passed to the DQN agent; must not be None when
            ``name`` is "dqn". Ignored by the other agent types.
        env: Snake environment; its ``observation_shape`` and
            ``num_actions`` attributes are read when building a DQN agent.

    Returns:
        The newly constructed agent instance.

    Raises:
        ValueError: if a "dqn" agent is requested and ``model`` is None.
        KeyError: if ``name`` is not one of the supported keys.
    """

    from snakeai.agent import DeepQNetworkAgent, HumanAgent, RandomActionAgent

    if name == "human":
        return HumanAgent()

    if name == "dqn":
        if model is None:
            raise ValueError("A model file is required for a DQN agent.")

        # Collect the constructor arguments first, then build the agent.
        dqn_options = {
            "model": model,
            "memory_size": -1,
            "num_last_frames": 4,
            "env_shape": env.observation_shape,
            "num_actions": env.num_actions,
        }
        return DeepQNetworkAgent(**dqn_options)

    if name == "random":
        return RandomActionAgent()

    # None of the known keys matched.
    raise KeyError(f'Unknown agent type: "{name}"')
コード例 #5
0
class PyGameGUI:
    """ Pygame window that runs Snake episodes for a human or an AI agent. """

    # Frame-rate cap handed to the Pygame clock on every loop iteration.
    FPS_LIMIT = 60
    # Stopwatch thresholds after which the game advances one timestep
    # (presumably milliseconds - confirm against Stopwatch.time()).
    AI_TIMESTEP_DELAY = 100
    HUMAN_TIMESTEP_DELAY = 500
    # Side length of one field cell on screen.
    CELL_SIZE = 10

    # Keys a human player may use; the position in this list matters, see
    # map_key_to_snake_action().
    SNAKE_CONTROL_KEYS = [
        pygame.K_UP,
        pygame.K_LEFT,
        pygame.K_DOWN,
        pygame.K_RIGHT,
    ]

    def __init__(self):
        """ Initialise Pygame and put the GUI into its initial, unloaded state. """
        pygame.init()

        # The agent defaults to a HumanAgent until load_agent() replaces it.
        self.agent = HumanAgent()

        # Set later by load_environment() (env, screen) and run() (fps_clock).
        self.env = self.screen = self.fps_clock = None

        self.timestep_watch = Stopwatch()

    def load_environment(self, environment):
        """ Attach the RL environment and open a square window for its field. """
        self.env = environment

        side = self.env.field.size * self.CELL_SIZE
        self.screen = pygame.display.set_mode((side, side))
        self.screen.fill(Colors.SCREEN_BACKGROUND)
        pygame.display.set_caption('Snake')

    def load_agent(self, agent):
        """ Replace the agent that drives the snake. """
        self.agent = agent

    def render_cell(self, x, y):
        """ Paint the single field cell located at field coordinates (x, y). """
        size = self.CELL_SIZE
        outer = pygame.Rect(x * size, y * size, size, size)
        cell_type = self.env.field[x, y]

        # Empty cells are simply filled with the background color.
        if cell_type == CellType.EMPTY:
            pygame.draw.rect(self.screen, Colors.SCREEN_BACKGROUND, outer)
            return

        # Other cells get a 1-wide outline plus a smaller filled square inside.
        color = Colors.CELL_TYPE[cell_type]
        pygame.draw.rect(self.screen, color, outer, 1)

        shrink = size // 6 * 2
        inner = outer.inflate((-shrink, -shrink))
        pygame.draw.rect(self.screen, color, inner)

    def render(self):
        """ Paint every cell of the field. """
        field_size = self.env.field.size
        for col in range(field_size):
            for row in range(field_size):
                self.render_cell(col, row)

    def map_key_to_snake_action(self, key):
        """
        Translate a control key into a snake action.

        The action table is rotated left by the key's position in
        SNAKE_CONTROL_KEYS and then indexed by the position of the snake's
        current direction in ALL_SNAKE_DIRECTIONS.
        """
        action_table = [
            SnakeAction.MAINTAIN_DIRECTION,
            SnakeAction.TURN_LEFT,
            SnakeAction.MAINTAIN_DIRECTION,
            SnakeAction.TURN_RIGHT,
        ]

        key_pos = self.SNAKE_CONTROL_KEYS.index(key)
        heading_pos = ALL_SNAKE_DIRECTIONS.index(self.env.snake.direction)

        rotated = np.roll(action_table, -key_pos)
        return rotated[heading_pos]

    def run(self, num_episodes=1):
        """ Play ``num_episodes`` episodes, stopping early on a quit request. """
        pygame.display.update()
        self.fps_clock = pygame.time.Clock()

        try:
            for _ in range(num_episodes):
                self.run_episode()
                # Pause after each episode before continuing.
                pygame.time.wait(1500)
        except QuitRequestedError:
            # A quit request ends the run quietly.
            return

    def run_episode(self):
        """
        Play one episode until the environment reports its end.

        Raises:
            QuitRequestedError: on a window-close event or the Escape key.
        """

        # Start a fresh episode.
        self.timestep_watch.reset()
        timestep_result = self.env.new_episode()
        self.agent.begin_episode()

        is_human_agent = isinstance(self.agent, HumanAgent)
        if is_human_agent:
            timestep_delay = self.HUMAN_TIMESTEP_DELAY
        else:
            timestep_delay = self.AI_TIMESTEP_DELAY

        episode_over = False
        while not episode_over:
            # Default for this frame; a human key press may override it below.
            action = SnakeAction.MAINTAIN_DIRECTION

            # Drain the event queue.
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    raise QuitRequestedError
                if event.type != pygame.KEYDOWN:
                    continue

                if event.key == pygame.K_ESCAPE:
                    raise QuitRequestedError
                if is_human_agent and event.key in self.SNAKE_CONTROL_KEYS:
                    action = self.map_key_to_snake_action(event.key)

            # Advance the game when the delay has elapsed, or at once when a
            # human player has chosen a turn.
            delay_elapsed = self.timestep_watch.time() >= timestep_delay
            human_turned = (
                is_human_agent and action != SnakeAction.MAINTAIN_DIRECTION
            )

            if delay_elapsed or human_turned:
                self.timestep_watch.reset()

                # Non-human agents pick their action from the last result.
                if not is_human_agent:
                    action = self.agent.act(
                        timestep_result.observation,
                        timestep_result.reward,
                    )

                self.env.choose_action(action)
                timestep_result = self.env.timestep()

                if timestep_result.is_episode_end:
                    self.agent.end_episode()
                    episode_over = True

            # Draw the frame (including the final one) and update the caption.
            self.render()
            score = self.env.snake.length - self.env.initial_snake_length
            pygame.display.set_caption(f'Snake  [Score: {score:02d}]')
            pygame.display.update()
            self.fps_clock.tick(self.FPS_LIMIT)