Пример #1
0
def test_single_move_moves_entire_body():
    """A single move shifts every segment one cell forward (north by default)."""
    snake = Snake(Point(4, 5))

    snake.move()

    expected_body = [(4, 4), (4, 5), (4, 6)]
    assert snake.head == expected_body[0]
    assert snake.length == len(expected_body)
    assert snake.tail == expected_body[-1]
    assert list(snake.body) == expected_body
Пример #2
0
def test_multiple_moves_respect_direction():
    """Each move follows whatever direction is set at the time of the move."""
    snake = Snake(Point(4, 5), length=5)

    # (direction to set before moving, number of moves); None keeps the current heading.
    itinerary = [
        (None, 1),
        (SnakeDirection.WEST, 2),
        (SnakeDirection.SOUTH, 1),
    ]
    for heading, steps in itinerary:
        if heading is not None:
            snake.direction = heading
        for _ in range(steps):
            snake.move()

    assert list(snake.body) == [(2, 5), (2, 4), (3, 4), (4, 4), (4, 5)]
Пример #3
0
def test_peek_next_move_respects_current_direction():
    """peek_next_move() reports the cell one step ahead in the current direction."""
    snake = Snake(Point(2, 5))

    expectations = (
        (SnakeDirection.NORTH, (2, 4)),
        (SnakeDirection.EAST, (3, 5)),
        (SnakeDirection.SOUTH, (2, 6)),
        (SnakeDirection.WEST, (1, 5)),
    )
    for heading, expected_cell in expectations:
        snake.direction = heading
        assert snake.peek_next_move() == expected_cell
Пример #4
0
def test_place_snake_with_default_length_applies_default_layout():
    """Without an explicit length the snake is 3 cells long, facing north, body trailing south."""
    origin = Point(4, 2)
    snake = Snake(origin)

    assert snake.head == origin
    assert snake.tail == Point(4, 4)
    assert snake.length == 3
    assert snake.direction == SnakeDirection.NORTH
    assert list(snake.body) == [(4, 2), (4, 3), (4, 4)]
Пример #5
0
    def new_episode(self):
        """
        Start a fresh episode.

        Rebuilds the level, resets statistics, places a new snake and a fruit,
        records the initial timestep, and returns it as a zero-reward
        TimestepResult.
        """
        self.field.create_level()
        self.stats.reset()
        self.timestep_index = 0

        spawn_point = self.field.find_snake_head()
        self.snake = Snake(spawn_point, length=self.initial_snake_length)
        self.field.place_snake(self.snake)
        self.generate_fruit()

        # Clear per-episode state before the initial observation is recorded.
        self.current_action = None
        self.is_game_over = False

        initial_result = TimestepResult(
            observation=self.get_observation(),
            reward=0,
            is_episode_end=self.is_game_over,
        )
        self.record_timestep_stats(initial_result)
        return initial_result
Пример #6
0
def test_place_snake_given_position_on_map_places_correctly():
    """Placing a snake at the map's start position draws its head ('S') and body ('s')."""
    expected_rows = [
        '#######',
        '#.....#',
        '#.....#',
        '#..S..#',
        '#..s..#',
        '#..s..#',
        '#######',
    ]

    field = Field(small_level_map)
    field.create_level()
    field.place_snake(Snake(field.find_snake_head(), length=3))

    assert str(field).split('\n') == expected_rows
Пример #7
0
def test_grow_extends_head_in_current_direction():
    """grow() adds a new head cell in the current direction while the tail stays put."""
    snake = Snake(Point(3, 5))
    fixed_tail = (3, 7)

    # (direction, expected length, expected head) after each single grow().
    checkpoints = (
        (SnakeDirection.NORTH, 4, (3, 4)),
        (SnakeDirection.WEST, 5, (2, 4)),
        (SnakeDirection.SOUTH, 6, (2, 5)),
    )
    for heading, expected_length, expected_head in checkpoints:
        snake.direction = heading
        snake.grow()
        assert snake.length == expected_length
        assert snake.head == expected_head
        assert snake.tail == fixed_tail

    # Three more cells southwards, then one to the east.
    for _ in range(3):
        snake.grow()
    snake.direction = SnakeDirection.EAST
    snake.grow()

    assert snake.length == 10
    assert snake.head == (3, 8)
    assert snake.tail == fixed_tail
    assert list(snake.body) == [
        (3, 8),
        (2, 8),
        (2, 7),
        (2, 6),
        (2, 5),
        (2, 4),
        (3, 4),
        (3, 5),
        (3, 6),
        (3, 7),
    ]
Пример #8
0
def test_place_snake_with_custom_length_respects_length():
    """An explicit length yields a straight body of that many cells trailing south."""
    snake = Snake(Point(5, 1), length=5)
    expected_body = [(5, y) for y in range(1, 6)]

    assert snake.head == Point(5, 1)
    assert snake.tail == Point(5, 5)
    assert snake.length == 5
    assert list(snake.body) == expected_body
Пример #9
0
def test_turn_left_turns_relatively_to_current_direction():
    """Four consecutive left turns cycle N -> W -> S -> E -> N."""
    snake = Snake(Point(3, 5))
    snake.direction = SnakeDirection.NORTH

    for expected_heading in (
        SnakeDirection.WEST,
        SnakeDirection.SOUTH,
        SnakeDirection.EAST,
        SnakeDirection.NORTH,
    ):
        snake.turn_left()
        assert snake.direction == expected_heading
Пример #10
0
class Environment(object):
    """
    Represents the RL environment for the Snake game that implements the game logic,
    provides rewards for the agent and keeps track of game statistics.
    """

    def __init__(self, config, verbose=1):
        """
        Create a new Snake RL environment.

        Args:
            config (dict): level configuration, typically found in JSON configs.
                Reads the keys 'field', 'initial_snake_length', 'rewards' and,
                optionally, 'max_step_limit' (defaults to 1000).
            verbose (int): verbosity level:
                0 = do not write any debug information;
                1 = write a CSV file containing the statistics for every episode;
                2 = same as 1, but also write a full log file containing the state of each timestep.
        """
        self.field = Field(level_map=config['field'])
        self.snake = None
        self.fruit = None
        self.initial_snake_length = config['initial_snake_length']
        self.rewards = config['rewards']
        self.max_step_limit = config.get('max_step_limit', 1000)
        self.is_game_over = False

        self.timestep_index = 0
        self.current_action = None
        self.stats = EpisodeStatistics()
        self.verbose = verbose
        # Output files are opened lazily by record_timestep_stats() and released by close().
        self.debug_file = None
        self.stats_file = None

    def seed(self, value):
        """ Seed both the `random` and NumPy global RNGs to make results reproducible. """
        random.seed(value)
        np.random.seed(value)

    @property
    def observation_shape(self):
        """ Get the shape of the state observed at each timestep as (size_x, size_y). """
        return self.field.size_x, self.field.size_y

    @property
    def num_actions(self):
        """ Get the number of actions the agent can take. """
        return len(ALL_SNAKE_ACTIONS)

    def new_episode(self):
        """
        Reset the environment and begin a new episode.

        Returns:
            TimestepResult: the initial observation with zero reward.
        """
        self.field.create_level()
        self.stats.reset()
        self.timestep_index = 0

        self.snake = Snake(self.field.find_snake_head(), length=self.initial_snake_length)
        self.field.place_snake(self.snake)
        self.generate_fruit()
        self.current_action = None
        self.is_game_over = False

        result = TimestepResult(
            observation=self.get_observation(),
            reward=0,
            is_episode_end=self.is_game_over
        )

        self.record_timestep_stats(result)
        return result

    def record_timestep_stats(self, result):
        """
        Record environment statistics according to the verbosity level.

        On first use, opens `debug_stats/snake-env-<timestamp>.csv` (verbose >= 1)
        and `debug_log/snake-env-<timestamp>.log` (verbose >= 2), creating the
        directories if needed. The files stay open across episodes until
        `close()` is called.
        """
        import os

        timestamp = time.strftime('%Y%m%d-%H%M%S')

        # Write CSV header for the stats file.
        if self.verbose >= 1 and self.stats_file is None:
            # Avoid FileNotFoundError when the output directory is missing.
            os.makedirs('debug_stats', exist_ok=True)
            self.stats_file = open(f'debug_stats/snake-env-{timestamp}.csv', 'w')
            # An empty slice of the stats frame serializes to just the header row.
            stats_csv_header_line = self.stats.to_dataframe()[:0].to_csv(index=False)
            print(stats_csv_header_line, file=self.stats_file, end='', flush=True)

        # Create a blank debug log file.
        if self.verbose >= 2 and self.debug_file is None:
            os.makedirs('debug_log', exist_ok=True)
            self.debug_file = open(f'debug_log/snake-env-{timestamp}.log', 'w')

        self.stats.record_timestep(self.current_action, result)
        self.stats.timesteps_survived = self.timestep_index

        if self.verbose >= 2:
            print(result, file=self.debug_file)

        # Log episode stats if the appropriate verbosity level is set.
        if result.is_episode_end:
            if self.verbose >= 1:
                stats_csv_line = self.stats.to_dataframe().to_csv(header=False, index=False)
                print(stats_csv_line, file=self.stats_file, end='', flush=True)
            if self.verbose >= 2:
                print(self.stats, file=self.debug_file)

    def close(self):
        """
        Close the statistics and debug log files, if any were opened.

        Safe to call multiple times. If recording continues afterwards, new
        files are opened.
        """
        if self.stats_file is not None:
            self.stats_file.close()
            self.stats_file = None
        if self.debug_file is not None:
            self.debug_file.close()
            self.debug_file = None

    def get_observation(self):
        """ Observe the state of the environment as a copy of the field's cell grid. """
        # Copy so that later in-place field updates don't mutate observations already handed out.
        return np.copy(self.field._cells)

    def choose_action(self, action):
        """
        Choose the action that will be taken at the next timestep.

        TURN_LEFT / TURN_RIGHT change the snake's direction immediately; any
        other action leaves the direction unchanged.
        """
        self.current_action = action
        if action == SnakeAction.TURN_LEFT:
            self.snake.turn_left()
        elif action == SnakeAction.TURN_RIGHT:
            self.snake.turn_right()

    def timestep(self):
        """
        Execute the timestep and return the new observable state.

        Rewards: the 'ate_fruit' reward scaled by the new snake length when a
        fruit is eaten; otherwise the 'timestep' reward. If the snake dies, the
        reward is replaced by the 'died' reward.

        Returns:
            TimestepResult: the new observation, the reward, and whether the episode ended.
        """
        self.timestep_index += 1
        reward = 0

        old_head = self.snake.head
        old_tail = self.snake.tail

        # Are we about to eat the fruit?
        if self.snake.peek_next_move() == self.fruit:
            self.snake.grow()
            self.generate_fruit()
            # The snake grew, so its tail did not move. None presumably tells the
            # field not to vacate a tail cell; confirm in update_snake_footprint.
            old_tail = None
            reward += self.rewards['ate_fruit'] * self.snake.length
            self.stats.fruits_eaten += 1

        # If not, just move forward.
        else:
            self.snake.move()
            reward += self.rewards['timestep']

        self.field.update_snake_footprint(old_head, old_tail, self.snake.head)

        # Hit a wall or own body?
        if not self.is_alive():
            if self.has_hit_wall():
                self.stats.termination_reason = 'hit_wall'
            if self.has_hit_own_body():
                self.stats.termination_reason = 'hit_own_body'

            self.field[self.snake.head] = CellType.SNAKE_HEAD
            self.is_game_over = True
            reward = self.rewards['died']

        # Exceeded the limit of moves? A death on this same step takes precedence,
        # so its termination reason is not overwritten.
        elif self.timestep_index >= self.max_step_limit:
            self.is_game_over = True
            self.stats.termination_reason = 'timestep_limit_exceeded'

        result = TimestepResult(
            observation=self.get_observation(),
            reward=reward,
            is_episode_end=self.is_game_over
        )

        self.record_timestep_stats(result)
        return result

    def generate_fruit(self, position=None):
        """
        Place a fruit on the field and remember its position.

        Args:
            position: the cell to place the fruit on; when None, a random
                empty cell is chosen by the field.
        """
        if position is None:
            position = self.field.get_random_empty_cell()
        self.field[position] = CellType.FRUIT
        self.fruit = position

    def has_hit_wall(self):
        """ True if the snake's head is on a wall cell, False otherwise. """
        return self.field[self.snake.head] == CellType.WALL

    def has_hit_own_body(self):
        """ True if the snake's head is on a snake-body cell, False otherwise. """
        return self.field[self.snake.head] == CellType.SNAKE_BODY

    def is_alive(self):
        """ True if the snake has hit neither a wall nor its own body, False otherwise. """
        return not self.has_hit_wall() and not self.has_hit_own_body()