class Agent:
    """
    Reinforcement learning agent interacting with the game environment.

    The agent owns a game engine, a replay memory and a dataset built on top
    of that memory. `self.game_state` always holds the observation the next
    call to `move` will act upon.
    """

    def __init__(self, model, memory_size, iterations_num, batch_size,
                 game_board_size=20, random_start=True):
        """
        Build the environment, the replay memory and the dataset.

        Args:
            model: callable mapping a state tensor to per-action scores; must
                expose `actions_num` (read by `move`).
            memory_size: forwarded to `Memory`.
            iterations_num: forwarded to `RLDataset`.
            batch_size: forwarded to `RLDataset`.
            game_board_size: forwarded to `Engine`.
            random_start: forwarded to `Engine`.
        """
        self.model = model
        self.game_engine = Engine(game_board_size, random_start)
        self.replay_memory = Memory(memory_size)
        self.dataset = RLDataset(self.replay_memory, iterations_num, batch_size)
        self.game_state = self.get_game_state()

    def get_dataset(self):
        """ Returns dataset based on replay memory """
        return self.dataset

    def get_game_state(self):
        """
        Return the engine's current state as a tensor with a leading
        dimension of size 1 added in front.
        """
        state = self.game_engine.get_game_state()
        return torch.from_numpy(state).unsqueeze(0)

    def play_full_game(self, max_moves=1000):
        """
        Play one greedy game (no exploration) and return the points scored.

        The game stops when the engine reports the snake is no longer alive
        or after `max_moves` moves. Nothing is written to the replay memory.
        The engine is reset before and after the game.
        """
        self.game_engine.reset()
        for _ in range(max_moves):
            # Pure inference: no autograd graph is needed for an argmax.
            with torch.no_grad():
                out = self.model(self.get_game_state())
            direction = self.translate_action(torch.argmax(out).item())
            self.game_engine.next_round(direction)
            if not self.game_engine.alive:
                break
        points = self.game_engine.points
        self.game_engine.reset()
        # The engine was driven outside of `move`, so the cached observation
        # is stale. Refresh it, otherwise the next `move` would act on (and
        # store in replay memory) a state unrelated to the engine's real one.
        self.game_state = self.get_game_state()
        return points

    def move(self, epsilon):
        """
        Make single interaction with the game environment.

        With probability `epsilon` a uniformly random action is taken,
        otherwise the action with the highest model output. The transition
        `(state, one_hot_action, reward, new_state, terminal)` is appended to
        the replay memory. On a terminal transition the engine is reset and
        the cached state becomes the post-reset observation.
        """
        actions_num = self.model.actions_num
        if random.random() <= epsilon:
            action_idx = random.randint(0, actions_num - 1)
        else:
            # Only the argmax is used, so gradients are not required.
            with torch.no_grad():
                model_output = self.model(self.game_state)
            action_idx = torch.argmax(model_output).item()

        # One-hot encoding of the chosen action, shape (1, actions_num).
        action = torch.zeros(actions_num)
        action[action_idx] = 1
        action = action.unsqueeze(0)

        direction = self.translate_action(action_idx)
        reward, terminal = self.game_engine.next_round(direction)
        new_state = self.get_game_state()

        self.replay_memory.append(
            (self.game_state, action, reward, new_state, terminal)
        )

        if terminal:
            # The stored transition keeps the terminal observation; the agent
            # itself continues from a fresh game.
            self.game_engine.reset()
            new_state = self.get_game_state()
        self.game_state = new_state

    def translate_action(self, action):
        """
        Translate action to new direction

        Actions:
            0 - turn left
            1 - go straight
            2 - turn right

        The result is relative to the engine's current direction and wraps
        around modulo 4.
        """
        direction = self.game_engine.direction
        return (direction + action - 1) % 4

    def warmup(self, num):
        """ Make <num> of random moves """
        for _ in range(num):
            # epsilon == 1 makes every move random.
            self.move(1)
class SnakeGame(Widget):
    """
    Widget that draws the snake board and advances the game either from
    keyboard input (`update`) or from a loaded model (`update_with_model`).

    Directions are indices into ['up', 'right', 'down', 'left'].
    """

    # Keyboard keys mapped to the direction names understood by
    # `change_snake_direction`.
    _KEY_TO_DIRECTION = {'w': 'up', 's': 'down', 'a': 'left', 'd': 'right'}

    def __init__(self, **kwargs):
        """
        Args:
            **kwargs: forwarded to the base widget, except the optional
                'score_label' entry, which is kept as the object whose
                `text` attribute shows the score.
        """
        self.score_label = kwargs.pop('score_label', None)
        super().__init__(**kwargs)
        self._keyboard = Window.request_keyboard(self._keyboard_closed, self)
        self._keyboard.bind(on_key_down=self._on_keyboard_down)

        self.engine = Engine(board_size=20)
        self.round_time = .05
        self.model_path = './models/model.ptl'

        # Each cell occupies block_size pixels plus a 1 pixel gap.
        self.block_size = 10
        self.board_length = (self.block_size + 1) * self.engine.board_size

        # game_direction is the direction applied on the last tick;
        # game_next_direction is the pending one requested by the player.
        self.game_direction = self.engine.direction
        self.game_next_direction = self.game_direction

    def change_snake_direction(self, new_direction):
        """
        Queue a new direction by name ('up', 'right', 'down' or 'left').

        A request to turn straight back onto the current direction is
        ignored. Raises ValueError for an unknown name.
        """
        directions = ['up', 'right', 'down', 'left']
        new_d = directions.index(new_direction)
        opposite = (self.game_direction + 2) % 4
        if new_d != opposite:
            self.game_next_direction = new_d

    def update(self, dt):
        """Advance one round using the queued player direction, then redraw."""
        if self.engine.alive:
            self.game_direction = self.game_next_direction
            self.engine.next_round(self.game_direction)
            self.draw_board()
            self.update_score()

    def init_ai(self):
        """Load the model from `self.model_path` and freeze it."""
        self.model = SnakeNet.load_from_checkpoint(self.model_path)
        self.model.freeze()

    def update_with_model(self, dt):
        """
        Advance one round using the model's choice, then redraw.

        The model output index minus one (-1, 0 or +1 for a three-way
        output) is added to the current direction modulo 4. `init_ai` must
        have been called first so that `self.model` exists.
        """
        if self.engine.alive:
            state = torch.from_numpy(self.engine.get_game_state()).unsqueeze(0)
            # Inference only; no gradients are needed for an argmax.
            with torch.no_grad():
                output = self.model(state)
            action = torch.argmax(output).item() - 1
            self.game_direction = (self.game_direction + action) % 4
            self.engine.next_round(self.game_direction)
            self.draw_board()
            self.update_score()

    def update_score(self):
        """Write the current points to the score label, if one was given."""
        # 'score_label' is optional in __init__; without this guard the
        # first tick would fail with AttributeError on None.
        if self.score_label is None:
            return
        score = self.engine.points
        self.score_label.text = f'Points: {int(score)}'

    def draw_board(self):
        """Clear the canvas and redraw the border, the snake and the fruit."""
        self.canvas.clear()
        border_width = 5
        cell = self.block_size + 1
        block = (self.block_size, self.block_size)
        with self.canvas:
            # Offsets that center the board inside the widget.
            self.padding_x = (self.width - self.board_length) // 2
            self.padding_y = (self.height - self.board_length) // 2
            left, bottom = self.padding_x, self.padding_y
            length = self.board_length

            # Four border strips: left, bottom, right, top.
            Rectangle(pos=(left - border_width, bottom),
                      size=(border_width, length))
            Rectangle(pos=(left, bottom - border_width),
                      size=(length, border_width))
            Rectangle(pos=(left + length, bottom),
                      size=(border_width, length))
            Rectangle(pos=(left, bottom + length),
                      size=(length, border_width))

            # Snake segments.
            Color(.59, .91, .12)
            for x, y in self.engine.snake:
                Rectangle(pos=(left + x * cell, bottom + y * cell), size=block)

            # Fruit.
            x, y = self.engine.fruit
            Color(.93, .83, .05)
            Rectangle(pos=(left + x * cell, bottom + y * cell), size=block)

    def _reset_game(self):
        """Reset the engine and re-seed both directions from it."""
        self.engine.reset()
        # Mirrors __init__: without this the next tick would push the
        # pre-reset direction into a freshly reset engine.
        self.game_direction = self.engine.direction
        self.game_next_direction = self.game_direction

    def _keyboard_closed(self):
        """Unbind the key handler and drop the keyboard reference."""
        self._keyboard.unbind(on_key_down=self._on_keyboard_down)
        self._keyboard = None

    def _on_keyboard_down(self, keyboard, keycode, text, modifiers):
        """Handle w/a/s/d as direction changes and r as a game restart."""
        key = keycode[1]
        if key in self._KEY_TO_DIRECTION:
            self.change_snake_direction(self._KEY_TO_DIRECTION[key])
        elif key == 'r':
            self._reset_game()