def __init__(self, config, verbose=1):
    """
    Set up a Snake RL environment from a level configuration.

    Args:
        config (dict): level configuration, typically found in JSON configs.
            The keys 'field', 'initial_snake_length' and 'rewards' are read
            unconditionally; 'max_step_limit' is optional and defaults to 1000.
        verbose (int): verbosity level:
            0 = do not write any debug information;
            1 = write a CSV file containing the statistics for every episode;
            2 = same as 1, but also write a full log file containing the state of each timestep.
    """
    # Level layout and the settings taken from the configuration.
    self.field = Field(level_map=config['field'])
    self.initial_snake_length = config['initial_snake_length']
    self.rewards = config['rewards']
    self.max_step_limit = config.get('max_step_limit', 1000)

    # Game state; the snake, fruit and action start out unset.
    self.snake = self.fruit = self.current_action = None
    self.is_game_over = False
    self.timestep_index = 0
    self.stats = EpisodeStatistics()

    # Debug output settings; no output file is open initially.
    self.verbose = verbose
    self.debug_file = self.stats_file = None
def test_get_random_empty_cell_only_one_cell_left_returns_it():
    """ When the map has a single '.' cell, that cell is the one returned. """
    level_map = [
        '#####',
        '#Sss#',
        '#O.s#',
        '#sss#',
        '#####',
    ]
    field = Field(level_map)
    field.create_level()

    only_empty_cell = (2, 2)
    assert field.get_random_empty_cell() == only_empty_cell
def test_get_random_empty_cell_many_available_returns_any():
    """ When several '.' cells exist, the returned cell is one of them. """
    level_map = [
        '#####',
        '#.Ss#',
        '#..s#',
        '#.ss#',
        '#####',
    ]
    field = Field(level_map)
    field.create_level()

    candidates = ((1, 1), (1, 2), (2, 2), (1, 3))
    chosen = field.get_random_empty_cell()
    assert chosen in candidates
def test_get_random_empty_cell_no_cells_left_throws():
    """ A map without any '.' cell makes the lookup raise IndexError. """
    level_map = [
        '####',
        '#Ss#',
        '#ss#',
        '####',
    ]
    field = Field(level_map)
    field.create_level()

    with pytest.raises(IndexError):
        field.get_random_empty_cell()
def test_find_snake_head_no_snake_on_map_throws():
    """ Looking for the snake head on a map without an 'S' cell raises ValueError. """
    level_map = [
        '#####',
        '#...#',
        '#...#',
        '#...#',
        '#####',
    ]
    field = Field(level_map)
    field.create_level()

    with pytest.raises(ValueError):
        field.find_snake_head()
def test_place_snake_given_position_on_map_places_correctly():
    """ A length-3 snake placed at the map's head position renders as expected. """
    field = Field(small_level_map)
    field.create_level()

    head_position = field.find_snake_head()
    snake = Snake(head_position, length=3)
    field.place_snake(snake)

    expected_rows = [
        '#######',
        '#.....#',
        '#.....#',
        '#..S..#',
        '#..s..#',
        '#..s..#',
        '#######',
    ]
    rendered_rows = str(field).split('\n')
    assert rendered_rows == expected_rows
def test_find_snake_head_snake_present_finds_it():
    """ The snake head of the small level map is reported at (3, 3). """
    field = Field(small_level_map)
    field.create_level()

    expected_head = (3, 3)
    assert field.find_snake_head() == expected_head
def test_field_created_from_map_renders_same_map():
    """ Rendering a freshly created field gives back the rows of its level map. """
    field = Field(small_level_map)
    field.create_level()

    rendered_rows = str(field).split('\n')
    assert rendered_rows == small_level_map
class Environment(object):
    """
    Represents the RL environment for the Snake game that implements the game logic,
    provides rewards for the agent and keeps track of game statistics.
    """

    def __init__(self, config, verbose=1):
        """
        Create a new Snake RL environment.

        Args:
            config (dict): level configuration, typically found in JSON configs.
                The keys 'field', 'initial_snake_length' and 'rewards' are required;
                'max_step_limit' is optional and defaults to 1000.
            verbose (int): verbosity level:
                0 = do not write any debug information;
                1 = write a CSV file containing the statistics for every episode;
                2 = same as 1, but also write a full log file containing the state of each timestep.
        """
        self.field = Field(level_map=config['field'])
        self.snake = None
        self.fruit = None
        self.initial_snake_length = config['initial_snake_length']
        self.rewards = config['rewards']
        self.max_step_limit = config.get('max_step_limit', 1000)
        self.is_game_over = False

        self.timestep_index = 0
        self.current_action = None
        self.stats = EpisodeStatistics()
        self.verbose = verbose
        # Output files are opened lazily by record_timestep_stats().
        self.debug_file = None
        self.stats_file = None

    def seed(self, value):
        """ Initialize the random state of the environment to make results reproducible. """
        random.seed(value)
        np.random.seed(value)

    @property
    def observation_shape(self):
        """ Get the shape of the state observed at each timestep. """
        return self.field.size_x, self.field.size_y

    @property
    def num_actions(self):
        """ Get the number of actions the agent can take. """
        return len(ALL_SNAKE_ACTIONS)

    def new_episode(self):
        """
        Reset the environment and begin a new episode.

        Returns:
            The TimestepResult describing the initial state (reward 0).
        """
        self.field.create_level()
        self.stats.reset()
        self.timestep_index = 0

        self.snake = Snake(self.field.find_snake_head(), length=self.initial_snake_length)
        self.field.place_snake(self.snake)
        self.generate_fruit()
        self.current_action = None
        self.is_game_over = False

        result = TimestepResult(
            observation=self.get_observation(),
            reward=0,
            is_episode_end=self.is_game_over
        )

        self.record_timestep_stats(result)
        return result

    def record_timestep_stats(self, result):
        """
        Record environment statistics according to the verbosity level.

        The CSV stats file (verbose >= 1) and the debug log file (verbose >= 2)
        are opened on first use; their directories are created if missing.

        Args:
            result: the timestep result to record; its `is_episode_end`
                attribute decides whether the episode summary is written.
        """
        import os

        needs_stats_file = self.verbose >= 1 and self.stats_file is None
        needs_debug_file = self.verbose >= 2 and self.debug_file is None

        # Only build the timestamp when a file actually has to be created, and
        # build it once so both files opened in the same call share it.
        if needs_stats_file or needs_debug_file:
            timestamp = time.strftime('%Y%m%d-%H%M%S')

        # Write CSV header for the stats file.
        if needs_stats_file:
            os.makedirs('debug_stats', exist_ok=True)
            self.stats_file = open(f'debug_stats/snake-env-{timestamp}.csv', 'w')
            stats_csv_header_line = self.stats.to_dataframe()[:0].to_csv(index=None)
            print(stats_csv_header_line, file=self.stats_file, end='', flush=True)

        # Create a blank debug log file.
        if needs_debug_file:
            os.makedirs('debug_log', exist_ok=True)
            self.debug_file = open(f'debug_log/snake-env-{timestamp}.log', 'w')

        self.stats.record_timestep(self.current_action, result)
        self.stats.timesteps_survived = self.timestep_index

        if self.verbose >= 2:
            print(result, file=self.debug_file)

        # Log episode stats if the appropriate verbosity level is set.
        if result.is_episode_end:
            if self.verbose >= 1:
                stats_csv_line = self.stats.to_dataframe().to_csv(header=False, index=None)
                print(stats_csv_line, file=self.stats_file, end='', flush=True)
            if self.verbose >= 2:
                print(self.stats, file=self.debug_file)

    def close(self):
        """
        Close the stats and debug files if they are open.

        Safe to call more than once. If statistics are recorded again
        afterwards, new files are opened on demand.
        """
        for attr_name in ('stats_file', 'debug_file'):
            handle = getattr(self, attr_name, None)
            if handle is not None:
                handle.close()
                setattr(self, attr_name, None)

    def get_observation(self):
        """ Observe the state of the environment (a copy of the field cells). """
        return np.copy(self.field._cells)

    def choose_action(self, action):
        """ Choose the action that will be taken at the next timestep. """
        self.current_action = action
        if action == SnakeAction.TURN_LEFT:
            self.snake.turn_left()
        elif action == SnakeAction.TURN_RIGHT:
            self.snake.turn_right()

    def timestep(self):
        """
        Execute the timestep and return the new observable state.

        Returns:
            The TimestepResult holding the new observation, the reward earned
            during this timestep and whether the episode has ended.
        """
        self.timestep_index += 1
        reward = 0

        old_head = self.snake.head
        old_tail = self.snake.tail

        # Are we about to eat the fruit?
        if self.snake.peek_next_move() == self.fruit:
            self.snake.grow()
            self.generate_fruit()
            # The snake grew, so no tail cell is passed to the footprint update.
            old_tail = None
            reward += self.rewards['ate_fruit'] * self.snake.length
            self.stats.fruits_eaten += 1

        # If not, just move forward.
        else:
            self.snake.move()
            reward += self.rewards['timestep']

        self.field.update_snake_footprint(old_head, old_tail, self.snake.head)

        # Hit a wall or own body?
        died = not self.is_alive()
        if died:
            if self.has_hit_wall():
                self.stats.termination_reason = 'hit_wall'
            if self.has_hit_own_body():
                self.stats.termination_reason = 'hit_own_body'

            self.field[self.snake.head] = CellType.SNAKE_HEAD
            self.is_game_over = True
            # Dying replaces any reward accumulated during this timestep.
            reward = self.rewards['died']

        # Exceeded the limit of moves? A death on this very timestep keeps its
        # own termination reason instead of being reported as a timeout.
        if not died and self.timestep_index >= self.max_step_limit:
            self.is_game_over = True
            self.stats.termination_reason = 'timestep_limit_exceeded'

        result = TimestepResult(
            observation=self.get_observation(),
            reward=reward,
            is_episode_end=self.is_game_over
        )

        self.record_timestep_stats(result)
        return result

    def generate_fruit(self, position=None):
        """
        Generate a new fruit.

        Args:
            position: cell to put the fruit in; when None, a random empty
                cell of the field is used.
        """
        if position is None:
            position = self.field.get_random_empty_cell()
        self.field[position] = CellType.FRUIT
        self.fruit = position

    def has_hit_wall(self):
        """ True if the snake has hit a wall, False otherwise. """
        return self.field[self.snake.head] == CellType.WALL

    def has_hit_own_body(self):
        """ True if the snake has hit its own body, False otherwise. """
        return self.field[self.snake.head] == CellType.SNAKE_BODY

    def is_alive(self):
        """ True if the snake is still alive, False otherwise. """
        return not self.has_hit_wall() and not self.has_hit_own_body()