def __init__(self):
    """
    Initialise pygame and put the GUI into its default state.

    A ``HumanAgent`` is installed as the initial agent and a fresh
    ``Stopwatch`` is created for timing timesteps. The environment, the
    screen and the FPS clock are not available yet and start out as None.
    """
    pygame.init()

    self.agent = HumanAgent()
    # Nothing is loaded or displayed at construction time.
    self.env = self.screen = self.fps_clock = None
    self.timestep_watch = Stopwatch()
def create_agent(name, model, model2):
    """
    Build the Snake AI agent selected by ``name``.

    Args:
        name (str): key identifying the agent type; one of 'human',
            'minimaxdqn', 'minimaxdqnsingle' or 'random'.
        model: pre-trained model. Passed as ``model_1`` to the 'minimaxdqn'
            agent and as ``model`` to the 'minimaxdqnsingle' agent; not
            used by the other agent types.
        model2: second pre-trained model. Passed as ``model_2`` to the
            'minimaxdqn' agent; not used by the other agent types.

    Returns:
        A newly constructed agent instance.

    Raises:
        KeyError: if ``name`` is not one of the recognised keys.
    """
    from snakeai.agent import DeepQNetworkAgent, HumanAgent, RandomActionAgent

    # NOTE(review): the two Minimax agent classes used below are not part of
    # the import above; presumably they are imported at module level -- confirm.

    # Settings shared by both model-driven agents.
    shared_settings = {'memory_size': -1, 'num_last_frames': 4}

    if name == 'human':
        return HumanAgent()

    if name == 'minimaxdqn':
        return MinimaxDeepQNetworkAgent(
            model_1=model,
            model_2=model2,
            **shared_settings,
        )

    if name == 'minimaxdqnsingle':
        return MinimaxSingleDeepQNetworkAgent(model=model, **shared_settings)

    if name == 'random':
        return RandomActionAgent()

    raise KeyError(f'Unknown agent type: "{name}"')
def create_agent(name, model):
    """
    Instantiate the Snake AI agent registered under ``name``.

    Args:
        name (str): key identifying the agent type; one of 'human', 'dqn'
            or 'random'.
        model: pre-trained model handed to the DQN agent. Only consulted
            when ``name`` is 'dqn'.

    Returns:
        A newly constructed agent instance.

    Raises:
        ValueError: if a 'dqn' agent is requested and ``model`` is None.
        KeyError: if ``name`` is not one of the recognised keys.
    """
    from snakeai.agent import DeepQNetworkAgent, HumanAgent, RandomActionAgent

    if name == 'human':
        return HumanAgent()

    if name == 'dqn':
        # The DQN agent cannot be built without a model.
        if model is None:
            raise ValueError('A model file is required for a DQN agent.')
        agent = DeepQNetworkAgent(model=model, memory_size=-1, num_last_frames=4)
        return agent

    if name == 'random':
        return RandomActionAgent()

    raise KeyError(f'Unknown agent type: "{name}"')
def create_agent(name, model, env):
    """
    Instantiate the Snake AI agent registered under ``name``.

    Args:
        name (str): key identifying the agent type; one of "human", "dqn"
            or "random".
        model: pre-trained model handed to the DQN agent. Only consulted
            when ``name`` is "dqn".
        env: Snake environment instance. Its ``observation_shape`` and
            ``num_actions`` attributes are read when building a DQN agent;
            it is not touched for the other agent types.

    Returns:
        A newly constructed agent instance.

    Raises:
        ValueError: if a "dqn" agent is requested and ``model`` is None.
        KeyError: if ``name`` is not one of the recognised keys.
    """
    from snakeai.agent import DeepQNetworkAgent, HumanAgent, RandomActionAgent

    if name == "human":
        return HumanAgent()

    if name == "dqn":
        # Validate the model before looking at the environment.
        if model is None:
            raise ValueError("A model file is required for a DQN agent.")
        dqn_settings = {
            "model": model,
            "memory_size": -1,
            "num_last_frames": 4,
            "env_shape": env.observation_shape,
            "num_actions": env.num_actions,
        }
        return DeepQNetworkAgent(**dqn_settings)

    if name == "random":
        return RandomActionAgent()

    raise KeyError(f'Unknown agent type: "{name}"')
class PyGameGUI:
    """
    Pygame window that displays a Snake environment and steps an agent
    (human or AI) through its episodes.
    """

    # Frame-rate cap handed to the pygame clock on every frame.
    FPS_LIMIT = 60
    # Stopwatch thresholds that trigger the next environment timestep for
    # AI-driven and human-driven play respectively.
    # NOTE(review): presumably milliseconds -- confirm against Stopwatch.time().
    AI_TIMESTEP_DELAY = 100
    HUMAN_TIMESTEP_DELAY = 500
    # Side length of a single field cell, in screen pixels.
    CELL_SIZE = 10

    # Keys that steer the snake. A key's position in this list is used as an
    # index by map_key_to_snake_action(), so the order matters.
    SNAKE_CONTROL_KEYS = [
        pygame.K_UP,
        pygame.K_LEFT,
        pygame.K_DOWN,
        pygame.K_RIGHT,
    ]

    def __init__(self):
        """ Initialise pygame and start with a human agent and nothing loaded. """
        pygame.init()

        self.agent = HumanAgent()
        # Filled in by load_environment() (env, screen) and run() (fps_clock).
        self.env = self.screen = self.fps_clock = None
        self.timestep_watch = Stopwatch()

    def load_environment(self, environment):
        """
        Attach the RL environment and open a window sized to its field.

        The window is square, CELL_SIZE pixels per field cell, cleared to the
        background colour and captioned 'Snake'.
        """
        self.env = environment

        side_px = self.env.field.size * self.CELL_SIZE
        self.screen = pygame.display.set_mode((side_px, side_px))
        self.screen.fill(Colors.SCREEN_BACKGROUND)
        pygame.display.set_caption('Snake')

    def load_agent(self, agent):
        """ Replace the agent that plays in the GUI. """
        self.agent = agent

    def render_cell(self, x, y):
        """
        Paint the field cell at field coordinates (x, y).

        An empty cell is filled with the background colour. Any other cell is
        drawn as a 1-pixel outline with a smaller filled square inside it,
        both in the colour registered for that cell type.
        """
        cell_px = self.CELL_SIZE
        outer_rect = pygame.Rect(x * cell_px, y * cell_px, cell_px, cell_px)
        cell_type = self.env.field[x, y]

        if cell_type == CellType.EMPTY:
            pygame.draw.rect(self.screen, Colors.SCREEN_BACKGROUND, outer_rect)
            return

        color = Colors.CELL_TYPE[cell_type]
        # Outline first (width 1), then the filled inner square.
        pygame.draw.rect(self.screen, color, outer_rect, 1)

        shrink = cell_px // 6 * 2
        inner_rect = outer_rect.inflate((-shrink, -shrink))
        pygame.draw.rect(self.screen, color, inner_rect)

    def render(self):
        """ Paint every cell of the field. """
        field_size = self.env.field.size
        for x in range(field_size):
            for y in range(field_size):
                self.render_cell(x, y)

    def map_key_to_snake_action(self, key):
        """
        Translate a control key press into a snake action.

        The action table is rotated left by the key's index in
        SNAKE_CONTROL_KEYS, and the entry at the index of the snake's current
        direction in ALL_SNAKE_DIRECTIONS is returned.

        Raises:
            ValueError: if ``key`` is not in SNAKE_CONTROL_KEYS.
        """
        action_table = [
            SnakeAction.MAINTAIN_DIRECTION,
            SnakeAction.TURN_LEFT,
            SnakeAction.MAINTAIN_DIRECTION,
            SnakeAction.TURN_RIGHT,
        ]

        key_idx = self.SNAKE_CONTROL_KEYS.index(key)
        direction_idx = ALL_SNAKE_DIRECTIONS.index(self.env.snake.direction)

        rotated = np.roll(action_table, -key_idx)
        return rotated[direction_idx]

    def run(self, num_episodes=1):
        """
        Play ``num_episodes`` episodes in the GUI, pausing 1.5 seconds after
        each one. A quit request ends the run early without raising.
        """
        pygame.display.update()
        self.fps_clock = pygame.time.Clock()

        try:
            for _ in range(num_episodes):
                self.run_episode()
                pygame.time.wait(1500)
        except QuitRequestedError:
            # The user asked to quit: stop quietly.
            pass

    def run_episode(self):
        """
        Play one episode from start to finish.

        Raises:
            QuitRequestedError: if the window is closed or Escape is pressed.
        """
        # Reset the timer, the environment and the agent.
        self.timestep_watch.reset()
        timestep_result = self.env.new_episode()
        self.agent.begin_episode()

        is_human_agent = isinstance(self.agent, HumanAgent)
        if is_human_agent:
            timestep_delay = self.HUMAN_TIMESTEP_DELAY
        else:
            timestep_delay = self.AI_TIMESTEP_DELAY

        episode_over = False
        while not episode_over:
            action = self._poll_action(is_human_agent)

            # Step the environment when the delay has elapsed, or right away
            # when a human player has just steered the snake.
            delay_elapsed = self.timestep_watch.time() >= timestep_delay
            human_steered = is_human_agent and action != SnakeAction.MAINTAIN_DIRECTION

            if delay_elapsed or human_steered:
                self.timestep_watch.reset()

                if not is_human_agent:
                    action = self.agent.act(timestep_result.observation, timestep_result.reward)

                self.env.choose_action(action)
                timestep_result = self.env.timestep()

                if timestep_result.is_episode_end:
                    self.agent.end_episode()
                    episode_over = True

            # The final frame of the episode is still drawn before leaving.
            self._present_frame()

    def _poll_action(self, is_human_agent):
        """ Drain pending pygame events and return the action they request. """
        action = SnakeAction.MAINTAIN_DIRECTION

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                raise QuitRequestedError
            if event.type != pygame.KEYDOWN:
                continue

            # Control keys only steer the snake when a human is playing.
            if is_human_agent and event.key in self.SNAKE_CONTROL_KEYS:
                action = self.map_key_to_snake_action(event.key)
            if event.key == pygame.K_ESCAPE:
                raise QuitRequestedError

        return action

    def _present_frame(self):
        """ Redraw the field, refresh the score caption and cap the frame rate. """
        self.render()

        score = self.env.snake.length - self.env.initial_snake_length
        pygame.display.set_caption(f'Snake [Score: {score:02d}]')

        pygame.display.update()
        self.fps_clock.tick(self.FPS_LIMIT)