コード例 #1
0
ファイル: human.py プロジェクト: jofa974/Snake
    def play(self):
        """Play one game controlled from the keyboard.

        Keys found in ``ui.CONTROLS`` steer the snake, SPACE asks the
        environment for a screenshot and closing the window kills the
        snake, which ends the loop. When the game is over the final score
        is drawn and left on screen for two seconds.
        """
        self.snake = Snake()
        self.apple = Apple()
        points = 0
        frame_delay = 150.0 / 1000.0

        while not self.snake.dead:
            for evt in pygame.event.get():
                if evt.type == pygame.QUIT:
                    self.snake.dead = True
                is_key_press = evt.type == pygame.KEYDOWN
                if is_key_press and evt.key in ui.CONTROLS:
                    self.snake.change_direction(evt.key)
                if is_key_press and evt.key == pygame.K_SPACE:
                    self.env.take_screenshot()

            self.snake.move()
            self.snake.detect_collisions()

            if self.snake.eat(self.apple):
                self.snake.grow()
                self.snake.update()
                self.apple.new_random()
                points += 1

            self.env.draw_everything("Score: {}".format(points),
                                     [self.snake, self.apple])
            time.sleep(frame_delay)

        self.env.draw_everything("GAME OVER! Your score is {}".format(points),
                                 [self.snake, self.apple])
        time.sleep(2)
コード例 #2
0
class Random(Brain):
    """Brain that moves the snake at random; the game is always displayed."""

    def __init__(self):
        super().__init__(do_display=True)
        self.env.set_caption("Snake: Random mode")

    def play(self):
        """Play a single game with random turns and show the final score.

        On every frame a coin flip decides whether to turn. When it does,
        a key on the other axis is chosen at random: left/right when the
        horizontal speed is zero, up/down otherwise. Closing the window
        ends the game.
        """
        self.apple = Apple()
        self.snake = Snake()
        apples_eaten = 0
        horizontal_keys = [pygame.K_LEFT, pygame.K_RIGHT]
        vertical_keys = [pygame.K_UP, pygame.K_DOWN]

        while not self.snake.dead:
            for evt in pygame.event.get():
                if evt.type == pygame.QUIT:
                    self.snake.dead = True

            if random.randint(0, 1):
                candidates = (horizontal_keys if self.snake.speed[0] == 0
                              else vertical_keys)
                self.snake.change_direction(random.choice(candidates))

            self.snake.move()
            self.snake.detect_collisions()

            if self.snake.eat(self.apple):
                self.snake.grow()
                self.apple.new_random()
                apples_eaten += 1

            self.env.draw_everything("Score: {}".format(apples_eaten),
                                     [self.snake, self.apple])
            time.sleep(100.0 / 1000.0)

        self.env.draw_everything(
            "GAME OVER! The random score is {}".format(apples_eaten),
            [self.snake, self.apple])
        time.sleep(2)
コード例 #3
0
ファイル: bfs.py プロジェクト: jofa974/Snake
    def play(self):
        """Play one game following breadth-first-search paths to the apple.

        Whenever no planned move is left, a new path from the head to the
        apple is computed and turned into moves; if no path exists the
        snake keeps going forward. Closing the window ends the game.

        Returns:
            The number of apples eaten.
        """
        self.apple = Apple()
        self.snake = Snake()
        # Drop moves planned during a previous game: they belong to another
        # snake and would otherwise be replayed at the start of this one.
        self.moves = []
        score = 0
        if self.do_display:
            self.env.set_caption("Snake: BFS mode")

        while not self.snake.dead:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    self.snake.dead = True

            if not self.moves:
                self.make_grid_map()
                path = self.BFS()
                self.get_moves_from_path(path)

            if self.moves:
                next_move = self.moves.pop(0)
            else:
                next_move = "forward"

            # "forward" is not a control key, so it keeps the current heading.
            if next_move in ui.CONTROLS:
                self.snake.change_direction(next_move)

            self.snake.move()

            self.snake.detect_collisions()

            if self.snake.eat(self.apple):
                self.snake.grow()
                self.snake.update()
                self.apple.new_random()
                score += 1

            if self.do_display:
                score_text = "Score: {}".format(score)
                self.env.draw_everything(score_text, [self.snake, self.apple])
                time.sleep(50.0 / 1000.0)

        return score
コード例 #4
0
class DQN(Brain):
    """Deep Q-learning brain that plays Snake from a feature vector.

    ``self.model``, ``self.optimizer`` and ``self.loss`` are used here but
    not created by this class; they are presumably set up by subclasses,
    which must also implement :meth:`get_input_data`.
    """

    def __init__(
        self,
        batch_size,
        gamma,
        memory_size,
        do_display=False,
        learning=True,
    ):
        super().__init__(do_display=do_display)
        self.model = None
        self.gamma = gamma  # discount factor for the next-state Q-value
        self.reward_window = []  # the last ``batch_size`` rewards
        self.memory = ReplayMemory(memory_size)
        self.batch_size = batch_size
        self.optimizer = None
        self.steps = 0  # moves played over the lifetime of this brain
        self.last_state = None
        self.last_action = 0
        self.last_reward = 0
        self.brain_file = "last_brain.pth"
        self.loss_history = []
        self.mean_reward_history = []
        self.list_of_rewards = []  # every reward observed by play()
        self.learning = learning

    @abstractmethod
    def get_input_data(self):
        """Return the feature vector describing the current game state."""
        raise NotImplementedError

    def select_action(self, state, epsilon):
        """Pick an action index for ``state``.

        While learning, with probability ``epsilon`` the action is sampled
        from the softmax of the model output; otherwise the arg-max action
        is taken.
        """
        # Action selection never back-propagates, so skip building a graph.
        with torch.no_grad():
            probs = F.softmax(self.model(state), dim=1)
        if self.learning and random.random() < epsilon:
            action = probs.multinomial(num_samples=1)[0][0]
        else:
            action = probs.argmax()
        return action.item()

    def action2direction_key(self, action):
        """Map a relative action (0 forward, 1 left, 2 right) to a key.

        "forward" maps to ``pygame.K_SPACE``. Left and right are resolved
        against the snake's current speed into an arrow key. Returns
        ``None`` if the speed matches none of the cases (e.g. zero speed).
        """
        directions = ["forward", "left", "right"]
        if directions[action] == "forward":
            return pygame.K_SPACE
        elif directions[action] == "left":
            if self.snake.speed[0] > 0:
                return pygame.K_UP
            if self.snake.speed[0] < 0:
                return pygame.K_DOWN
            if self.snake.speed[1] > 0:
                return pygame.K_RIGHT
            if self.snake.speed[1] < 0:
                return pygame.K_LEFT
        elif directions[action] == "right":
            if self.snake.speed[0] > 0:
                return pygame.K_DOWN
            if self.snake.speed[0] < 0:
                return pygame.K_UP
            if self.snake.speed[1] > 0:
                return pygame.K_LEFT
            if self.snake.speed[1] < 0:
                return pygame.K_RIGHT

    def mean_reward(self):
        """Return the mean of every reward recorded by :meth:`play`."""
        return np.mean(self.list_of_rewards)

    def save(self, filename=None):
        """Save model and optimizer state to ``filename`` (default brain file)."""
        if filename is None:
            filename = self.brain_file
        torch.save(
            {
                "state_dict": self.model.state_dict(),
                "optimizer": self.optimizer.state_dict(),
            },
            filename,
        )

    def save_best(self):
        """Save the current brain as ``best_brain.pth``."""
        self.save(filename="best_brain.pth")

    def load(self, filename=None):
        """Load model and optimizer state from ``filename`` if it exists.

        ``torch.load`` unpickles the file, so only load trusted files.
        """
        if filename is None:
            filename = self.brain_file
        print("Loading brain stored in {}".format(filename))
        if os.path.isfile(filename):
            checkpoint = torch.load(filename)
            self.model.load_state_dict(checkpoint["state_dict"])
            self.optimizer.load_state_dict(checkpoint["optimizer"])
        else:
            print("no checkpoint found ...")

    def load_best(self):
        """Load the brain stored in ``best_brain.pth``."""
        self.load(filename="best_brain.pth")

    def play(self, max_move=-1, init_training_data=None, epsilon=0):
        """Play games until ``max_move`` moves have been made.

        A first game is played with a budget of ``max_move`` moves. While
        learning, if the snake dies before the budget is used up, new games
        are started with the remaining budget so the whole epoch is played.
        With the default ``max_move=-1`` no move is made.

        Args:
            max_move: total number of moves to play.
            init_training_data: optional sequence of apple positions cycled
                through (restarting at each game) instead of random placement.
            epsilon: probability of sampling rather than taking the greedy
                action (only used while learning).

        Returns:
            The number of apples eaten in the first game.
        """
        nb_apples, nb_moves = self._play_single_game(
            max_move, init_training_data, epsilon)

        # Restart games to finish the epoch, but only while learning.
        # Iterating instead of recursing keeps the stack flat when many
        # short games are needed to consume the budget.
        remaining = max_move - nb_moves
        while self.learning and remaining > 0:
            _, moves_done = self._play_single_game(
                remaining, init_training_data, epsilon)
            if moves_done == 0:
                # Guard: a game that makes no progress would loop forever.
                break
            remaining -= moves_done

        return nb_apples

    def _play_single_game(self, max_move, init_training_data, epsilon):
        """Play one game of at most ``max_move`` moves.

        Each move's reward is one of the following; later rules override
        earlier ones:

        - the distance gained towards the apple, scaled by the grid
          diagonal, if the snake got closer;
        - -0.7 if it did not get closer;
        - -1 if it died;
        - +1 if it ate the apple.

        The reward is handed to :meth:`update` on the next step.

        Returns:
            ``(nb_apples, nb_moves)`` for this game.
        """
        self.snake = Snake()

        forbidden_positions = self.snake.get_body_position_list()
        if init_training_data:
            training_data = itertools.cycle(init_training_data)
            self.apple = Apple(forbidden=forbidden_positions,
                               xy=next(training_data))
        else:
            self.apple = Apple(forbidden=forbidden_positions)

        # Largest distance on the board, used to normalise distance rewards.
        grid_diagonal = np.sqrt(ui.X_GRID**2 + ui.Y_GRID**2)
        nb_moves = 0
        nb_apples = 0

        while (not self.snake.dead) and (nb_moves < max_move):

            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    sys.exit()

            nb_moves += 1
            self.steps += 1

            if self.do_display:
                score_text = "Score: {}".format(nb_apples)
                self.env.draw_everything(score_text, [self.snake, self.apple],
                                         flip=True)
                time.sleep(0.1)

            last_signal = self.get_input_data()

            next_action = self.update(
                self.last_reward,
                last_signal,
                nb_steps=nb_moves,
                epsilon=epsilon,
            )

            next_move = self.action2direction_key(next_action)
            if next_move in ui.CONTROLS:
                self.snake.change_direction(next_move)

            prev_dist = self.snake.get_distance_to_target(
                self.snake.get_position(0), self.apple.get_position(), norm=2)
            self.snake.move()
            new_dist = self.snake.get_distance_to_target(
                self.snake.get_position(0), self.apple.get_position(), norm=2)

            if new_dist < prev_dist:
                self.last_reward = (prev_dist - new_dist) / grid_diagonal
            else:
                self.last_reward = -0.7

            self.snake.detect_collisions()
            if self.snake.dead:
                self.last_reward = -1

            if self.snake.eat(self.apple):
                nb_apples += 1
                self.snake.grow()
                self.snake.update()
                forbidden_positions = self.snake.get_body_position_list()
                if init_training_data:
                    x, y = next(training_data)
                    self.apple.new(x, y, forbidden=forbidden_positions)
                else:
                    self.apple.new_random(forbidden=forbidden_positions)
                self.last_reward = 1

            self.list_of_rewards.append(self.last_reward)

        return nb_apples, nb_moves

    def update(self, reward, new_signal, nb_steps=-1, epsilon=-1.0):
        """Store the previous transition and choose the next action.

        While learning, the transition ``(last_state, new_state,
        last_action, last_reward)`` is pushed to replay memory. ``reward``
        then becomes ``last_reward`` and is added to the sliding reward
        window. ``nb_steps`` is accepted but not used.

        Returns:
            The chosen action index.
        """
        new_state = torch.Tensor(new_signal).unsqueeze(0)

        if self.learning:
            self.memory.push((
                self.last_state,
                new_state,
                torch.LongTensor([int(self.last_action)]),
                torch.Tensor([self.last_reward]),
            ))

        action = self.select_action(new_state, epsilon)

        self.last_action = action
        self.last_state = new_state
        self.last_reward = reward
        self.reward_window.append(reward)
        if len(self.reward_window) > self.batch_size:
            del self.reward_window[0]
        return action

    def learn(self):
        """Run one optimisation step on a batch sampled from replay memory.

        The target is ``reward + gamma * max_a' Q(next_state, a')``, with the
        next-state values detached. Returns the loss value.
        """
        (
            batch_state,
            batch_next_state,
            batch_action,
            batch_reward,
        ) = self.memory.sample(self.batch_size)

        outputs = (self.model(batch_state).gather(
            1, batch_action.unsqueeze(1)).squeeze(1))
        next_outputs = self.model(batch_next_state).detach().max(1)[0]
        targets = batch_reward + self.gamma * next_outputs
        loss = self.loss(outputs, targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # Diagnostic output when training diverges.
        if loss.item() > 1.0e7:
            print("outputs {}".format(outputs))
            print("targets {}".format(targets))
        return loss.item()
コード例 #5
0
    def play(self, max_move=-1, init_training_data=None, epsilon=0):
        """Play games until ``max_move`` moves have been made.

        A first game is played with a budget of ``max_move`` moves. While
        learning, if the snake dies before the budget is used up, new games
        are started with the remaining budget so the whole epoch is played.
        With the default ``max_move=-1`` no move is made.

        Args:
            max_move: total number of moves to play.
            init_training_data: optional sequence of apple positions cycled
                through (restarting at each game) instead of random placement.
            epsilon: exploration parameter forwarded to ``self.update``.

        Returns:
            The number of apples eaten in the first game.
        """
        nb_apples, nb_moves = self._play_single_game(
            max_move, init_training_data, epsilon)

        # Restart games to finish the epoch, but only while learning
        # (``self.learn`` is a method and would always be truthy).
        # Iterating instead of recursing keeps the stack flat when many
        # short games are needed to consume the budget.
        remaining = max_move - nb_moves
        while self.learning and remaining > 0:
            _, moves_done = self._play_single_game(
                remaining, init_training_data, epsilon)
            if moves_done == 0:
                # Guard: a game that makes no progress would loop forever.
                break
            remaining -= moves_done

        return nb_apples

    def _play_single_game(self, max_move, init_training_data, epsilon):
        """Play one game of at most ``max_move`` moves.

        Each move's reward is one of the following; later rules override
        earlier ones:

        - the distance gained towards the apple, scaled by the grid
          diagonal, if the snake got closer;
        - -0.7 if it did not get closer;
        - -1 if it died;
        - +1 if it ate the apple.

        The reward is handed to ``self.update`` on the next step.

        Returns:
            ``(nb_apples, nb_moves)`` for this game.
        """
        self.snake = Snake()

        forbidden_positions = self.snake.get_body_position_list()
        if init_training_data:
            training_data = itertools.cycle(init_training_data)
            self.apple = Apple(forbidden=forbidden_positions,
                               xy=next(training_data))
        else:
            self.apple = Apple(forbidden=forbidden_positions)

        # Largest distance on the board, used to normalise distance rewards.
        grid_diagonal = np.sqrt(ui.X_GRID**2 + ui.Y_GRID**2)
        nb_moves = 0
        nb_apples = 0

        while (not self.snake.dead) and (nb_moves < max_move):

            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    sys.exit()

            nb_moves += 1
            self.steps += 1

            if self.do_display:
                score_text = "Score: {}".format(nb_apples)
                self.env.draw_everything(score_text, [self.snake, self.apple],
                                         flip=True)
                time.sleep(0.1)

            last_signal = self.get_input_data()

            next_action = self.update(
                self.last_reward,
                last_signal,
                nb_steps=nb_moves,
                epsilon=epsilon,
            )

            next_move = self.action2direction_key(next_action)
            if next_move in ui.CONTROLS:
                self.snake.change_direction(next_move)

            prev_dist = self.snake.get_distance_to_target(
                self.snake.get_position(0), self.apple.get_position(), norm=2)
            self.snake.move()
            new_dist = self.snake.get_distance_to_target(
                self.snake.get_position(0), self.apple.get_position(), norm=2)

            if new_dist < prev_dist:
                self.last_reward = (prev_dist - new_dist) / grid_diagonal
            else:
                self.last_reward = -0.7

            self.snake.detect_collisions()
            if self.snake.dead:
                self.last_reward = -1

            if self.snake.eat(self.apple):
                nb_apples += 1
                self.snake.grow()
                self.snake.update()
                forbidden_positions = self.snake.get_body_position_list()
                if init_training_data:
                    x, y = next(training_data)
                    self.apple.new(x, y, forbidden=forbidden_positions)
                else:
                    self.apple.new_random(forbidden=forbidden_positions)
                self.last_reward = 1

            self.list_of_rewards.append(self.last_reward)

        return nb_apples, nb_moves
コード例 #6
0
    def play(self, max_move=-1, dump=False, training_data=None):
        """Play one game driven by the genetic-algorithm neural network.

        Args:
            max_move: stop after this many moves when positive; a value
                <= 0 lets the game run until the snake dies.
            dump: when true, pass ``self.gen_id`` and the final fitness to
                ``self.nn.dump_data``.
            training_data: optional sequence of apple positions cycled
                through instead of random placement.

        Returns:
            ``(score, fitness)``. ``score`` is the number of apples eaten.
            ``fitness`` is ``nb_moves + 2**score + 500 * score**2.1 -
            score**1.2 * (0.25 * score)**1.3``, minus 10 if the snake died.
        """
        self.snake = Snake()

        forbidden_positions = self.snake.get_body_position_list()
        if training_data:
            training_data = itertools.cycle(training_data)
            self.apple = Apple(forbidden=forbidden_positions,
                               xy=next(training_data))
        else:
            self.apple = Apple(forbidden=forbidden_positions)

        score = 0
        fitness = 0
        nb_moves = 0

        if self.do_display:
            matplotlib.use("Agg")
            self.env.set_caption(
                "Snake: Custom Neural Network optimized with a Genetic Algorithm"
            )
            fig = plt.figure(figsize=[3, 3], dpi=100)

        while not self.snake.dead:

            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    self.snake.dead = True

            # Feed the NN with input data
            input_data = self.get_input_data()
            self.nn.forward(input_data)

            # Take decision
            next_direction = self.nn.decide_direction()

            # Deduce which key to press based on next direction
            next_move = self.get_move_from_direction(next_direction)
            if next_move in ui.CONTROLS:
                self.snake.change_direction(next_move)

            self.snake.move()

            self.snake.detect_collisions()

            if self.snake.eat(self.apple):
                self.snake.grow()
                self.snake.update()
                forbidden_positions = self.snake.get_body_position_list()
                if training_data:
                    x, y = next(training_data)
                    self.apple.new(x, y, forbidden=forbidden_positions)
                else:
                    self.apple.new_random(forbidden=forbidden_positions)
                score += 1

            if self.do_display:
                score_text = "Score: {}".format(score)
                self.env.draw_everything(score_text, [self.snake, self.apple],
                                         flip=False)
                self.env.make_surf_from_figure_on_canvas(fig)
                time.sleep(0.01 / 1000.0)

            nb_moves += 1
            fitness = (nb_moves +
                       (math.pow(2, score) + math.pow(score, 2.1) * 500) -
                       (math.pow(score, 1.2) * math.pow(0.25 * score, 1.3)))
            if self.snake.dead:
                # Death penalty. It must come after the recomputation above;
                # applied earlier it would be overwritten and have no effect.
                fitness -= 10
            if max_move > 0 and nb_moves >= max_move:
                break

        if self.do_display:
            # Close the per-game figure so repeated games do not accumulate
            # open matplotlib figures.
            plt.close(fig)

        if dump:
            self.nn.dump_data(self.gen_id, fitness)

        return score, fitness
コード例 #7
0
class NN_GA(Brain):
    """
    Class that will play the game with a neural network optimized
    using a genetic algorithm.
    """
    def __init__(self, do_display, gen_id=(-1, -1), dna=None, hidden_nb=None):
        """Create the brain and its network.

        Args:
            do_display: whether the game is drawn.
            gen_id: identifier forwarded to ``ANN`` and ``dump_data``.
            dna: network parameters forwarded to ``ANN`` (``None`` allowed).
            hidden_nb: forwarded to ``ANN`` as ``hidden_nb``; defaults to
                ``[4]``.
        """
        super().__init__(do_display=do_display)
        if hidden_nb is None:
            # A fresh list per instance instead of a shared mutable default.
            hidden_nb = [4]
        self.nn = ANN(gen_id, dna, hidden_nb=hidden_nb)
        self.gen_id = gen_id

    def play(self, max_move=-1, dump=False, training_data=None):
        """Play one game driven by the neural network.

        Args:
            max_move: stop after this many moves when positive; a value
                <= 0 lets the game run until the snake dies.
            dump: when true, pass ``self.gen_id`` and the final fitness to
                ``self.nn.dump_data``.
            training_data: optional sequence of apple positions cycled
                through instead of random placement.

        Returns:
            ``(score, fitness)``. ``score`` is the number of apples eaten.
            ``fitness`` is ``nb_moves + 2**score + 500 * score**2.1 -
            score**1.2 * (0.25 * score)**1.3``, minus 10 if the snake died.
        """
        self.snake = Snake()

        forbidden_positions = self.snake.get_body_position_list()
        if training_data:
            training_data = itertools.cycle(training_data)
            self.apple = Apple(forbidden=forbidden_positions,
                               xy=next(training_data))
        else:
            self.apple = Apple(forbidden=forbidden_positions)

        score = 0
        fitness = 0
        nb_moves = 0

        if self.do_display:
            matplotlib.use("Agg")
            self.env.set_caption(
                "Snake: Custom Neural Network optimized with a Genetic Algorithm"
            )
            fig = plt.figure(figsize=[3, 3], dpi=100)

        while not self.snake.dead:

            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    self.snake.dead = True

            # Feed the NN with input data
            input_data = self.get_input_data()
            self.nn.forward(input_data)

            # Take decision
            next_direction = self.nn.decide_direction()

            # Deduce which key to press based on next direction
            next_move = self.get_move_from_direction(next_direction)
            if next_move in ui.CONTROLS:
                self.snake.change_direction(next_move)

            self.snake.move()

            self.snake.detect_collisions()

            if self.snake.eat(self.apple):
                self.snake.grow()
                self.snake.update()
                forbidden_positions = self.snake.get_body_position_list()
                if training_data:
                    x, y = next(training_data)
                    self.apple.new(x, y, forbidden=forbidden_positions)
                else:
                    self.apple.new_random(forbidden=forbidden_positions)
                score += 1

            if self.do_display:
                score_text = "Score: {}".format(score)
                self.env.draw_everything(score_text, [self.snake, self.apple],
                                         flip=False)
                self.env.make_surf_from_figure_on_canvas(fig)
                time.sleep(0.01 / 1000.0)

            nb_moves += 1
            fitness = (nb_moves +
                       (math.pow(2, score) + math.pow(score, 2.1) * 500) -
                       (math.pow(score, 1.2) * math.pow(0.25 * score, 1.3)))
            if self.snake.dead:
                # Death penalty. It must come after the recomputation above;
                # applied earlier it would be overwritten and have no effect.
                fitness -= 10
            if max_move > 0 and nb_moves >= max_move:
                break

        if self.do_display:
            # Close the per-game figure so repeated games do not accumulate
            # open matplotlib figures.
            plt.close(fig)

        if dump:
            self.nn.dump_data(self.gen_id, fitness)

        return score, fitness

    def get_move_from_direction(self, direction):
        """Turn a relative direction into a key for the snake.

        "forward" is returned unchanged. "left" and "right" are resolved
        against the current speed into an arrow key. ``None`` is returned
        if nothing matches.
        """
        if direction == "forward":
            return "forward"
        if direction == "left":
            if self.snake.speed[0] > 0:
                return pygame.K_UP
            if self.snake.speed[0] < 0:
                return pygame.K_DOWN
            if self.snake.speed[1] > 0:
                return pygame.K_RIGHT
            if self.snake.speed[1] < 0:
                return pygame.K_LEFT
        if direction == "right":
            if self.snake.speed[0] > 0:
                return pygame.K_DOWN
            if self.snake.speed[0] < 0:
                return pygame.K_UP
            if self.snake.speed[1] > 0:
                return pygame.K_LEFT
            if self.snake.speed[1] < 0:
                return pygame.K_RIGHT

    def get_input_data(self):
        """Return the six network inputs.

        The inputs are: clear ahead/left/right, then food ahead/left/right.
        """
        apple_pos = self.apple.get_position()
        input_data = [
            self.snake.is_clear_ahead(),
            self.snake.is_clear_left(),
            self.snake.is_clear_right(),
            self.snake.is_food_ahead(apple_pos),
            self.snake.is_food_left(apple_pos),
            self.snake.is_food_right(apple_pos),
        ]
        return input_data
コード例 #8
0
ファイル: bfs.py プロジェクト: jofa974/Snake
class BFS(Brain):
    """Brain that steers the snake along breadth-first-search paths.

    A path from the head to the apple is computed whenever the planned
    moves run out. Border cells and the snake's body are avoided.
    """

    def __init__(self, do_display):
        super().__init__(do_display=do_display)
        self.moves = []  # queued keys / "forward" still to be played
        # One pygame.Rect per grid cell, rows indexed by y; used by draw_grid.
        self.grid = []
        for y in range(ui.Y_GRID):
            grid = []
            for x in range(ui.X_GRID):
                grid.append(
                    pygame.Rect(
                        x * ui.BASE_SIZE,
                        y * ui.BASE_SIZE,
                        ui.BASE_SIZE,
                        ui.BASE_SIZE,
                    ))
            self.grid.append(grid)

    def play(self):
        """Play one game following BFS paths to the apple.

        Returns:
            The number of apples eaten.
        """
        self.apple = Apple()
        self.snake = Snake()
        # Drop moves planned during a previous game: they belong to another
        # snake and would otherwise be replayed at the start of this one.
        self.moves = []
        score = 0
        if self.do_display:
            self.env.set_caption("Snake: BFS mode")

        while not self.snake.dead:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    self.snake.dead = True

            if not self.moves:
                self.make_grid_map()
                path = self.BFS()
                self.get_moves_from_path(path)

            if self.moves:
                next_move = self.moves.pop(0)
            else:
                next_move = "forward"

            # "forward" is not a control key, so it keeps the current heading.
            if next_move in ui.CONTROLS:
                self.snake.change_direction(next_move)

            self.snake.move()

            self.snake.detect_collisions()

            if self.snake.eat(self.apple):
                self.snake.grow()
                self.snake.update()
                self.apple.new_random()
                score += 1

            if self.do_display:
                score_text = "Score: {}".format(score)
                self.env.draw_everything(score_text, [self.snake, self.apple])
                time.sleep(50.0 / 1000.0)

        return score

    def make_grid_map(self):
        """Build ``self.grid_map[y][x]``: True for open cells, False on the border."""
        self.grid_map = [[True for _ in range(ui.X_GRID)]
                         for _ in range(ui.Y_GRID)]
        self.grid_map[0][:] = [False for _ in range(ui.X_GRID)]
        self.grid_map[-1][:] = [False for _ in range(ui.X_GRID)]
        for i in range(ui.Y_GRID):
            self.grid_map[i][0] = False
            self.grid_map[i][-1] = False

    def BFS(self):
        """Return a shortest path from the snake's head to the apple.

        The path is a list of grid positions that starts with the head and
        ends with the apple. It moves one cell at a time through cells open
        in ``self.grid_map`` and not occupied by the snake. Returns ``None``
        when the apple cannot be reached.
        """
        height = ui.Y_GRID
        width = ui.X_GRID
        start = self.snake.get_position(0)
        end = self.apple.get_position()
        queue = deque([[start]])
        # Mark the head position itself as visited. ``set(start)`` would
        # instead hold its two coordinates. Then mark every body cell.
        visited = {start}
        for i in range(1, len(self.snake.body_list)):
            visited.add(self.snake.get_position(i))
        while queue:
            path = queue.popleft()
            nextp = path[-1]
            if nextp == end:
                return path
            x, y = nextp
            for neighbor in ((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)):
                xn, yn = neighbor
                ingrid = 0 <= xn < width and 0 <= yn < height
                if ingrid and neighbor not in visited and self.grid_map[yn][xn]:
                    queue.append(path + [neighbor])
                    visited.add(neighbor)
        return None

    def get_moves_from_path(self, path):
        """Convert ``path`` into ``self.moves``.

        A step along the current axis becomes "forward". A step on the
        other axis becomes the arrow key that turns the snake that way. An
        empty, ``None`` or single-cell path yields no moves.
        """
        direction = self.snake.speed
        self.moves = []
        if path and len(path) > 1:
            current = path[0]
            for place in path[1:]:
                if place[0] - current[0] == 1:
                    if direction[1] == 0:
                        self.moves.append("forward")
                    else:
                        self.moves.append(pygame.K_RIGHT)
                        direction = (ui.BASE_SPEED, 0)
                if place[0] - current[0] == -1:
                    if direction[1] == 0:
                        self.moves.append("forward")
                    else:
                        self.moves.append(pygame.K_LEFT)
                        direction = (-ui.BASE_SPEED, 0)
                if place[1] - current[1] == 1:
                    if direction[0] == 0:
                        self.moves.append("forward")
                    else:
                        self.moves.append(pygame.K_DOWN)
                        direction = (0, ui.BASE_SPEED)
                if place[1] - current[1] == -1:
                    if direction[0] == 0:
                        self.moves.append("forward")
                    else:
                        self.moves.append(pygame.K_UP)
                        direction = (0, -ui.BASE_SPEED)
                current = place

    def draw_grid(self):
        """Outline every grid cell in brown.

        NOTE(review): ``self.screen`` is assumed to be provided by the base
        class; it is not set anywhere visible here.
        """
        for rect in itertools.chain.from_iterable(self.grid):
            pygame.draw.rect(self.screen, ui.BROWN, rect, 3)
コード例 #9
0
from components.apple import Apple
from components.snake import Snake
from ui import BASE_SPEED

# Shared fixture: an apple placed at (20, 20); its position is the target
# passed to get_distance_to_target in the test below.
apple = Apple()
apple.new(20, 20)
apple_pos = apple.get_position()


def test_get_distance_to_target():
    snake = Snake(20, 40, (0, -BASE_SPEED))
    dist = snake.get_distance_to_target(snake.get_position(0),
                                        apple_pos,
                                        norm=1)
    assert dist == 20

    snake = Snake(0, 20, (0, BASE_SPEED))
    dist = snake.get_distance_to_target(snake.get_position(0),
                                        apple_pos,
                                        norm=1)
    assert dist == 20

    snake = Snake(20, 20, (0, BASE_SPEED))
    dist = snake.get_distance_to_target(snake.get_position(0),
                                        apple_pos,
                                        norm=1)
    assert dist == 0

    snake = Snake(0, 0, (0, BASE_SPEED))
    dist = snake.get_distance_to_target(snake.get_position(0),
                                        apple_pos,