Example #1
def show_gen(idx):
    global ginst

    genaction = GenAction(idx)
    env = SnakeEnv(use_simple=True)
    obs = env.reset()

    c = 0
    while True:
        c += 1
        action = genaction(obs)
        _, obs, done, _ = env(action)
        extra = '代数: {}'.format(ginst.n_gen)
        if c % 3 == 0:
            if ginst.n_gen:
                env.render(extra)
            else:
                env.render()

        # sleep(1 / ginst.n_gen)
        print(done, env.status.direction, env.life)
        if done:
            break

    sleep(1.5)
Example #2
File: play.py Project: kingyiusuen/snake
def run(display, retrain, num_episodes):
    pygame.init()
    env = SnakeEnv()
    agent = QlearningAgent(env)
    if not retrain:
        try:
            load_policy(agent)
        except Exception:
            # No saved policy available; start training from scratch.
            pass

    for _ in tqdm(range(num_episodes)):
        state = env.reset()
        done = False
        while not done:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    save_policy(agent)
                    pygame.quit()
                    sys.exit()
            if display:
                env.render()
            action = agent.act(state)
            next_state, reward, done = env.step(action)
            agent.update_q_value(state, reward, action, next_state, done)
            state = next_state
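        # Decay epsilon after each episode, clamped at min_epsilon.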
        agent.epsilon = max(agent.epsilon * agent.epsilon_decay_rate,
                            agent.min_epsilon)
    save_policy(agent)
Example #3
def fitness(solution, idx):
    global ginst
    if idx is None:
        return 0

    genaction = GenAction(idx)
    env = SnakeEnv(need_render=False,
                   use_simple=True,
                   set_life=int(GRID_WIDTH_NUM / 3))
    obs = env.reset()
    while True:
        action = genaction(obs)
        _, obs, done, _ = env(action)
        if done:
            break

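    # Fitness: squared move count rewards survival; body length is weighted
    # exponentially up to 10 segments and roughly linearly (x1024) beyond that.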
    bl = len(env.status.snake_body)
    if bl < 10:
        fscore = (env.n_moves**2) * (2**bl)
    else:
        fscore = env.n_moves**2
        fscore *= 1024
        fscore *= (bl - 9)

    if fscore > ginst.max_score:
        ginst.max_score = fscore
        ginst.max_idx = idx
        ginst.params = env.n_moves, bl
        print('find new best: ', fscore, env.n_moves, bl)

    return fscore
Example #4
def show_gen(idx):
    global ginst

    genaction = GenAction(idx)
    env = SnakeEnv(use_simple=True)
    obs = env.reset()

    while True:
        action = genaction(obs)
        _, obs, done, _ = env(action)
        env.render()
        if done:
            break

    env.close()
Example #5
import argparse

from stable_baselines.common.policies import CnnPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2

from snake_env import SnakeEnv

parser = argparse.ArgumentParser()
parser.add_argument('--mode', choices=['train', 'test'], default='test')
args = parser.parse_args()

env = SnakeEnv((20, 20), 'standard')
env = DummyVecEnv([lambda: env])
model = PPO2(CnnPolicy, env, verbose=1)

if args.mode == 'train':

    model.learn(total_timesteps=20000)
    model.save('policy_baseline_snake')

elif args.mode == 'test':

    model = PPO2.load('policy_baseline_snake', env=env)
    obs = env.reset()

    for i in range(1000):
        action, _states = model.predict(obs)
        obs, reward, done, info = env.step(action)
        env.render()
        if done:
            env.reset()

env.close()
Example #6
class Application:
    def __init__(self, args):
        self.args = args
        self.env = SnakeEnv(args.snake_head_x, args.snake_head_y, args.food_x, args.food_y)
        self.agent = Agent(self.env.get_actions(), args.Ne, args.C, args.gamma)
        
    def execute(self):
        if not self.args.human:
            if self.args.train_eps != 0:
                self.train()
            self.eval()
        self.show_games()

    def train(self):
        print("Train Phase:")
        self.agent.train()
        window = self.args.window
        self.points_results = []
        first_eat = True
        start = time.time()

        for game in range(1, self.args.train_eps + 1):
            state = self.env.get_state()
            dead = False
            action = self.agent.choose_action(state, 0, dead)
            while not dead:
                state, points, dead = self.env.step(action)

                # For debug convenience, you can check whether your Q-table matches ours for a given setting of parameters
                # (see the Debug Convenience part on the homework 4 web page)
                if first_eat and points == 1:
                    self.agent.save_model(utils.CHECKPOINT)
                    first_eat = False

                action = self.agent.choose_action(state, points, dead)

            points = self.env.get_points()
            self.points_results.append(points)
            if game % self.args.window == 0:
                print(
                    "Games:", len(self.points_results) - window, "-", len(self.points_results), 
                    "Points (Average:", sum(self.points_results[-window:])/window,
                    "Max:", max(self.points_results[-window:]),
                    "Min:", min(self.points_results[-window:]),")",
                )
            self.env.reset()
        print("Training takes", time.time() - start, "seconds")
        self.agent.save_model(self.args.model_name)

    def eval(self):
        print("Evaling Phase:")
        self.agent.eval()
        self.agent.load_model(self.args.model_name)
        points_results = []
        start = time.time()

        for game in range(1, self.args.test_eps + 1):
            state = self.env.get_state()
            dead = False
            action = self.agent.choose_action(state, 0, dead)
            while not dead:
                state, points, dead = self.env.step(action)
                action = self.agent.choose_action(state, points, dead)
            points = self.env.get_points()
            points_results.append(points)
            self.env.reset()

        print("Testing takes", time.time() - start, "seconds")
        print("Number of Games:", len(points_results))
        print("Average Points:", sum(points_results)/len(points_results))
        print("Max Points:", max(points_results))
        print("Min Points:", min(points_results))

    def show_games(self):
        print("Display Games")
        self.env.display()
        pygame.event.pump()
        self.agent.eval()
        points_results = []
        end = False
        for game in range(1, self.args.show_eps + 1):
            state = self.env.get_state()
            dead = False
            action = self.agent.choose_action(state, 0, dead)
            count = 0
            while not dead:
                count += 1
                pygame.event.pump()
                keys = pygame.key.get_pressed()
                if keys[K_ESCAPE] or self.check_quit():
                    end = True
                    break
                state, points, dead = self.env.step(action)
                # Qlearning agent
                if not self.args.human:
                    action = self.agent.choose_action(state, points, dead)
                # for human player
                else:
                    for event in pygame.event.get():
                        if event.type == pygame.KEYDOWN:
                            if event.key == pygame.K_UP:
                                action = 2
                            elif event.key == pygame.K_DOWN:
                                action = 3
                            elif event.key == pygame.K_LEFT:
                                action = 1
                            elif event.key == pygame.K_RIGHT:
                                action = 0
            if end:
                break
            self.env.reset()
            points_results.append(points)
            print("Game:", str(game)+"/"+str(self.args.show_eps), "Points:", points)
        if len(points_results) == 0:
            return
        print("Average Points:", sum(points_results)/len(points_results))

    def check_quit(self):
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                return True
        return False
Example #7
    ])
    model.compile(loss='mean_squared_error',
                  optimizer=tf.keras.optimizers.Adam(0.001))

    if weights:
        model.load_weights(weights)
    return model


def act(state, epsilon=0.1, step=0):
    """Predict an action."""
    return np.argmax(model.predict(np.array([state]))[0])


env = SnakeEnv()
model = create_model("weights.hdf5")
for i in range(1000):
    state = env.reset()
    step = 0
    while True:
        env.render()
        action = act(state)

        state, reward, done, _ = env.step(action)
        step += 1
        if done:
            print('Game', i + 1, '      Score:', env.score)
            break
env.close()
Example #8
                  (-1 + min_epsilon) / n_train), min_epsilon)
else:
    epsilons = np.arange(1, min_epsilon, (-1 + min_epsilon) / n_train)

# size of memory
for i_train in range(n_train):
    epsilon = epsilons[i_train]
    res.epsilon = epsilon
    print("Training: round {}, epsilon = {}".format(i_train, round(epsilon,
                                                                   2)))
    lengths_i_train = []
    scores_i_train = []
    for i_episode in range(n_episodes):
        i = 0
        done = False
        grid = env.reset()
        grid = grid.reshape((1, n_channels, env.nrow, env.ncol))
        t0 = time.time()
        while i < imax:
            i += 1
            source = grid.copy()
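            # Epsilon-greedy: explore with a random action with probability epsilon,
            # otherwise pick the action with the highest predicted value.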
            if epsilon >= np.random.rand():
                action = np.random.randint(4)
            else:
                action = np.argmax(model.predict(source))
            grid, reward, done = env.step(action)
            grid = grid.reshape((1, n_channels, env.nrow, env.ncol))

            observation = {'source': source, 'action': action,
                           'dest': grid, 'reward': reward, 'final': done}
            res.memory.append(observation)
Example #9
def play(record=0, no_render=False):
    env = SnakeEnv(need_render=not no_render, alg='最短路径')
    obs = env.reset()
    env.render()
    input()
    x, y = [], []

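    # Offset (head cell minus next cell) mapped onto environment actions,
    # e.g. (-1, 0) means the next cell lies to the right of the head.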
    directions = {
        (-1, 0): env.right,
        (1, 0): env.left,
        (0, -1): env.down,
        (0, 1): env.up
    }

    need_record = bool(record)
    new_dst = None
    origin_dst = None
    # counter = 20
    use_random = False
    while True:
        if not record and not no_render:
            env.render()
        src = np.where(obs == 2)
        src = int(src[1]), int(src[0])
        dst = np.where(obs == -1)
        dst = int(dst[0]), int(dst[1])

        if new_dst is not None:
            paths = bfs(obs, start=src, dst=new_dst)
        else:
            paths = bfs(obs, start=src, dst=dst)

        if paths is None:
            # origin_dst = dst
            # new_dst = (
            #     np.random.randint(0, obs.shape[0]),
            #     np.random.randint(0, obs.shape[1]),
            # )
            # counter -= 1
            # if counter <= 0:
            #     print('score: ', env.status.score)
            #     new_dst = None
            #     origin_dst = None
            #     counter = 20
            #     obs = env.reset()
            # continue
            use_random = True
        else:
            # A path exists; if the detour target was just reached, restore the
            # original destination before clearing the detour.
            if new_dst is not None and paths[1] == new_dst:
                new_dst = None
                if origin_dst is not None:
                    dst = origin_dst
                    origin_dst = None
                    # counter = 20
                    continue
            new_dst = None

        # if counter <= 0 or paths is None or len(paths) <= 1:
        #     print('score: ', env.status.score)
        #     obs = env.reset()
        #     continue

        if use_random:
            action = np.random.randint(0, 4)
            use_random = False
        else:
            dst = paths[1]
            dire = src[0] - dst[0], src[1] - dst[1]
            action = directions[dire]
        # import ipdb
        # ipdb.set_trace()
        if need_record:
            x.append(obs)
            y.append(action)
            if len(y) >= record:
                return x, y

            if len(y) % 1000 == 0:
                print(len(y))

        _, obs, done, _ = env(action)
        # counter = 20

        if done:
            print(env.status.score)
            sleep(1.5)
            break
    if not record and not no_render:
        env.render()

    env.close()
Example #10
def draw_graph():
    print(GRID_HEIGHT_NUM, GRID_WIDTH_NUM)
    # input()
    graph = build_graph(row=GRID_HEIGHT_NUM, col=GRID_WIDTH_NUM)
    total_graph = deepcopy(graph)
    env = SnakeEnv(set_life=100000, alg='HC + BFS', no_sight_disp=True)
    env.reset()
    sleep(1)

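    # Prune the full grid graph, then repeatedly break and reconnect edges until a
    # single cycle covering the grid remains (the "HC" part of the HC + BFS strategy).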
    graph, flags = deletion(graph, env)
    for sp in graph:
        for ep in graph[sp]:
            if flags[(sp, ep)]:
                # print(sp, ep)
                env.draw_connection(sp, ep, width=4)
            # env.render()

    import pygame
    pygame.display.update()
    pre_len = None
    while True:
        sd_len = destroy(graph, total_graph, flags, env=env)
        print('sd: ', sd_len)
        if pre_len is not None and pre_len == sd_len:
            global MAX_DEPTH
            print('+1')
            MAX_DEPTH += 1

        pre_len = sd_len

        show_graph(graph, flags, env)
        if not sd_len:
            break

    sleep(1)

    show_graph(graph, flags, env)
    counter = 0
    while not connector(graph, total_graph, flags, env):
        counter += 1
        print('counter: ', counter)

    sleep(1)

    for sp in graph:
        for ep in graph[sp]:
            if flags[(sp, ep)]:
                env.draw_connection(sp, ep, color=(0xff, 0xff, 0), width=4)

    show_graph(graph, flags, env)
    circle = get_list_circle(graph)
    print(circle)
    pos_encoder = {pos: i for i, pos in enumerate(circle)}
    # pos_decoder = {i: pos for i, pos in enumerate(circle)}
    pos_xy_decoder = {
        i: (pos % GRID_WIDTH_NUM, pos // GRID_WIDTH_NUM)
        for i, pos in enumerate(circle)
    }
    pos_xy_encoder = {(pos % GRID_WIDTH_NUM, pos // GRID_WIDTH_NUM): i
                      for i, pos in enumerate(circle)}
    obs = env.reset()
    c = 0
    while True:
        c += 1

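        # `remainder` throttles rendering: the board is redrawn only every
        # `remainder`-th step, so longer games play back faster.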
        if len(env.status.snake_body) < 15:
            remainder = 20
        elif len(env.status.snake_body) < 30:
            remainder = 20
        elif len(env.status.snake_body) < 60:
            remainder = 30
        elif len(env.status.snake_body) < 90:
            remainder = 30
        elif len(env.status.snake_body) < 120:
            remainder = 40
        elif len(env.status.snake_body) < 150:
            remainder = 80
        elif len(env.status.snake_body) < 300:
            remainder = 100
        elif len(env.status.snake_body) < (GRID_WIDTH_NUM * GRID_HEIGHT_NUM -
                                           10):
            remainder = 30
        else:
            remainder = 5
        bfs_action, dst = dfs_policy(obs, env)
        bfs_dst_idx = 100000000
        if dst:
            bfs_dst_idx = pos_xy_encoder[dst]
        head = env.status.snake_body[0]
        # Positions of the head and tail along the precomputed cycle.
        head_idx = pos_xy_encoder[head]
        tail_idx = pos_xy_encoder[env.status.snake_body[-1]]

        # Default move: follow the cycle to the next cell.
        hc_next_pos = pos_xy_decoder[(head_idx + 1) % len(graph)]

        directions = {
            (-1, 0): env.right,
            (1, 0): env.left,
            (0, -1): env.down,
            (0, 1): env.up
        }
        dire = head[0] - hc_next_pos[0], head[1] - hc_next_pos[1]
        print(head, hc_next_pos, dst, dst not in env.status.snake_body[:-1])
        print(head_idx, tail_idx, bfs_dst_idx)
        action = directions[dire]
        if not env.status.food_pos:
            show_graph(graph,
                       flags,
                       env,
                       update=True,
                       width=1,
                       extra='倍速: {} X'.format(remainder * 5))
            break

        food_idx = pos_xy_encoder[env.status.food_pos]
        if bfs_action:
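            # Only take the BFS shortcut when, measured from the tail along the cycle,
            # the shortcut target lies between the head and the food, so the move never
            # jumps past the tail.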
            print(food_idx, bfs_dst_idx, head_idx, tail_idx)
            print(rel_pos(food_idx, tail_idx, len(graph)),
                  rel_pos(bfs_dst_idx, tail_idx, len(graph)),
                  rel_pos(head_idx, tail_idx, len(graph)),
                  rel_pos(tail_idx, tail_idx, len(graph)))
            if rel_pos(food_idx, tail_idx, len(graph)) >= rel_pos(
                    bfs_dst_idx, tail_idx, len(graph)) >= rel_pos(
                        head_idx, tail_idx, len(graph)) >= rel_pos(
                            tail_idx, tail_idx, len(graph)):
                action = bfs_action

        reward, obs, done, _ = env(action)
        if done:
            show_graph(graph,
                       flags,
                       env,
                       update=True,
                       width=1,
                       extra='倍速: {} X'.format(remainder * 5))
            print(done)
            break
        # env.screen.blit(env.background, (0, 0))

        if c % remainder == 0:
            show_graph(graph,
                       flags,
                       env,
                       update=True,
                       width=1,
                       extra='倍速: {} X'.format(remainder * 5))
        # env.render(blit=False)

    show_graph(graph, flags, env, update=True, width=1)
    sleep(10)
    input()
Example #11
# -*- coding: utf-8 -*-

import time
from snake_env import SnakeEnv

from simple_mlp import Policy


def play(env, policy):
    obs = env.reset()
    while True:
        action = policy.predict(obs)
        reward, obs, done, _ = env(action)
        env.render()
        if done:
            obs = env.reset()
            time.sleep(1)
        # time.sleep(0.05)


if __name__ == '__main__':
    policy = Policy(pre_trained='pretrained/mlp-v0.joblib')
    env = SnakeEnv(alg='MLP')
    env.reset()
    env.render()
    input()
    play(env, policy)