Example #1
0
class Agent:
    """
    Reinforcement learning agent interacting with the game environment.

    The agent owns a game engine, a replay memory and a dataset built on top
    of that memory. ``self.game_state`` caches the engine's current state as
    a tensor (see ``get_game_state``) and is the state the next call to
    ``move`` acts from.
    """
    def __init__(self,
                 model,
                 memory_size,
                 iterations_num,
                 batch_size,
                 game_board_size=20,
                 random_start=True):
        """
        Build the engine, replay memory and dataset and cache the first state.

        Args:
            model: callable applied to a state tensor to score actions; it
                must also expose ``actions_num`` (read by ``move``).
            memory_size: forwarded to ``Memory``.
            iterations_num: forwarded to ``RLDataset``.
            batch_size: forwarded to ``RLDataset``.
            game_board_size: forwarded to ``Engine``.
            random_start: forwarded to ``Engine``.
        """
        self.model = model
        self.game_engine = Engine(game_board_size, random_start)
        self.replay_memory = Memory(memory_size)

        self.dataset = RLDataset(self.replay_memory, iterations_num,
                                 batch_size)

        self.game_state = self.get_game_state()

    def get_dataset(self):
        """
        Returns dataset based on replay memory
        """

        return self.dataset

    def get_game_state(self):
        """
        Return the engine's current state as a tensor.

        The numpy array returned by the engine is wrapped with
        ``torch.from_numpy`` and given an extra leading dimension of size one.
        """
        state = self.game_engine.get_game_state()

        return torch.from_numpy(state).unsqueeze(0)

    def play_full_game(self, max_moves=1000):
        """
        Play one game greedily with the model and return the points scored.

        The engine is reset before the game and again afterwards. The game
        stops when the engine reports it is no longer alive or after
        ``max_moves`` moves, whichever comes first. Nothing is written to the
        replay memory.

        Args:
            max_moves: upper bound on the number of moves played.

        Returns:
            The engine's ``points`` value at the end of the game.
        """
        self.game_engine.reset()

        for _ in range(max_moves):
            state = self.get_game_state()
            # Only the arg-max is needed, so skip autograd bookkeeping.
            with torch.no_grad():
                out = self.model(state)
            action_idx = torch.argmax(out).item()
            self.game_engine.next_round(self.translate_action(action_idx))

            if not self.game_engine.alive:
                break

        points = self.game_engine.points

        self.game_engine.reset()
        # The engine now holds a fresh game; refresh the cached state so the
        # next move() does not act from (and store) a state of the old game.
        self.game_state = self.get_game_state()

        return points

    def move(self, epsilon):
        """
        Make single interaction with the game environment

        With probability ``epsilon`` a uniformly random action is taken,
        otherwise the action with the highest model output. The resulting
        ``(state, action, reward, new_state, terminal)`` tuple is appended to
        the replay memory, where ``action`` is a one-hot tensor of shape
        ``(1, actions_num)``. If the round was terminal the engine is reset
        and the cached state is taken from the new game.

        Args:
            epsilon: exploration probability; 0 never explores, 1 always does.
        """
        actions_num = self.model.actions_num

        # random.random() lies in [0, 1), so the strict comparison makes
        # epsilon == 0 purely greedy and epsilon == 1 purely random.
        if random.random() < epsilon:
            action_idx = random.randrange(actions_num)
        else:
            # Only the arg-max is needed, so skip autograd bookkeeping.
            with torch.no_grad():
                model_output = self.model(self.game_state)
            action_idx = torch.argmax(model_output).item()

        # One-hot encoding of the chosen action.
        action = torch.zeros(1, actions_num)
        action[0, action_idx] = 1

        direction = self.translate_action(action_idx)
        reward, terminal = self.game_engine.next_round(direction)
        new_state = self.get_game_state()

        # The stored transition keeps the state observed right after the
        # move, even when that move ended the game.
        exp = (self.game_state, action, reward, new_state, terminal)
        self.replay_memory.append(exp)

        if terminal:
            self.game_engine.reset()
            new_state = self.get_game_state()
        self.game_state = new_state

    def translate_action(self, action):
        """
        Translate action to new direction

        Actions:
        0 - turn left
        1 - go straight
        2 - turn right

        The result is the engine's current direction shifted by
        ``action - 1`` and wrapped into the range 0..3.
        """
        direction = self.game_engine.direction
        new_direction = (direction + action - 1) % 4

        return new_direction

    def warmup(self, num):
        """
        Make <num> of random moves
        """
        for _ in range(num):
            # epsilon == 1 forces a random action on every move.
            self.move(1)
Example #2
0
class SnakeGame(Widget):
    """
    Widget that runs a snake game on an ``Engine`` and draws it on its canvas.

    The game can be advanced either from keyboard input (``update``) or by a
    loaded model (``update_with_model``). Directions are integers 0..3 that
    index ``_DIRECTIONS``.
    """

    # Direction names ordered so that a name's index is the integer direction
    # handed to Engine.next_round.
    _DIRECTIONS = ('up', 'right', 'down', 'left')

    # Keyboard key -> direction name.
    _KEY_DIRECTIONS = {'w': 'up', 's': 'down', 'a': 'left', 'd': 'right'}

    def __init__(self, **kwargs):
        """
        Create the widget, grab the keyboard and set up a 20x20 game.

        Args:
            **kwargs: forwarded to ``Widget`` after removing the optional
                ``score_label`` entry, an object whose ``text`` attribute is
                updated with the current score (``None`` disables this).
        """
        self.score_label = kwargs.pop('score_label', None)
        super().__init__(**kwargs)

        self._keyboard = Window.request_keyboard(self._keyboard_closed, self)
        self._keyboard.bind(on_key_down=self._on_keyboard_down)

        self.engine = Engine(board_size=20)
        self.round_time = .05

        self.model_path = './models/model.ptl'

        self.block_size = 10
        # Cells are placed block_size + 1 apart, leaving a one-unit gap.
        self.board_length = (self.block_size + 1) * self.engine.board_size

        self.game_direction = self.engine.direction
        self.game_next_direction = self.game_direction

    def change_snake_direction(self, new_direction):
        """
        Queue a direction change for the next round.

        Args:
            new_direction: one of ``'up'``, ``'right'``, ``'down'``,
                ``'left'``.

        Raises:
            ValueError: if ``new_direction`` is not one of the names above.
        """
        new_d = self._DIRECTIONS.index(new_direction)

        # Ignore a request to turn straight back onto the current heading.
        if (self.game_direction + 2) % 4 != new_d:
            self.game_next_direction = new_d

    def _play_round(self, direction):
        """Advance the engine one round in ``direction`` and redraw."""
        self.game_direction = direction
        self.engine.next_round(direction)
        self.draw_board()
        self.update_score()

    def _reset_game(self):
        """Restart the engine and re-sync the cached directions with it."""
        self.engine.reset()
        # Same initialisation as in __init__: without it the new game would
        # be driven by the heading (and any queued turn) of the old one.
        self.game_direction = self.engine.direction
        self.game_next_direction = self.game_direction

    def update(self, dt):
        """
        Play one round using the direction queued from the keyboard.

        Does nothing once the engine reports the snake is no longer alive.

        Args:
            dt: unused here.
        """
        if self.engine.alive:
            self._play_round(self.game_next_direction)

    def init_ai(self):
        """Load the model from ``self.model_path`` and freeze it."""
        self.model = SnakeNet.load_from_checkpoint(self.model_path)
        self.model.freeze()

    def update_with_model(self, dt):
        """
        Play one round using the action chosen by the model.

        ``init_ai`` must have been called first. The index of the largest
        model output, minus one, is added to the current direction (mod 4).
        Does nothing once the engine reports the snake is no longer alive.

        Args:
            dt: unused here.
        """
        if self.engine.alive:
            state = self.engine.get_game_state()
            output = self.model(torch.from_numpy(state).unsqueeze(0))
            action = torch.argmax(output).item() - 1
            self._play_round((self.game_direction + action) % 4)

    def update_score(self):
        """Write the engine's points to the score label, if one was given."""
        # score_label defaults to None in __init__; nothing to update then.
        if self.score_label is None:
            return
        score = self.engine.points
        self.score_label.text = f'Points: {int(score)}'

    def draw_board(self):
        """
        Redraw the whole board: border, snake segments and fruit.

        Also stores the board's offsets inside the widget in
        ``self.padding_x`` and ``self.padding_y``.
        """
        self.canvas.clear()

        border_width = 5
        self.padding_x = (self.width - self.board_length) // 2
        self.padding_y = (self.height - self.board_length) // 2

        left = self.padding_x
        bottom = self.padding_y
        length = self.board_length
        cell_step = self.block_size + 1
        cell_size = (self.block_size, self.block_size)

        with self.canvas:
            # Border: left, bottom, right and top strips around the board.
            Rectangle(pos=(left - border_width, bottom),
                      size=(border_width, length))
            Rectangle(pos=(left, bottom - border_width),
                      size=(length, border_width))
            Rectangle(pos=(left + length, bottom),
                      size=(border_width, length))
            Rectangle(pos=(left, bottom + length),
                      size=(length, border_width))

            Color(.59, .91, .12)
            for x, y in self.engine.snake:
                Rectangle(pos=(left + x * cell_step, bottom + y * cell_step),
                          size=cell_size)

            x, y = self.engine.fruit
            Color(.93, .83, .05)
            Rectangle(pos=(left + x * cell_step, bottom + y * cell_step),
                      size=cell_size)

    def _keyboard_closed(self):
        """Unbind the key handler and drop the keyboard reference."""
        self._keyboard.unbind(on_key_down=self._on_keyboard_down)
        self._keyboard = None

    def _on_keyboard_down(self, keyboard, keycode, text, modifiers):
        """
        Handle a key press: w/a/s/d steer the snake, r restarts the game.

        Only ``keycode[1]`` is inspected; other keys are ignored.
        """
        key = keycode[1]
        if key == 'r':
            self._reset_game()
        elif key in self._KEY_DIRECTIONS:
            self.change_snake_direction(self._KEY_DIRECTIONS[key])