示例#1
0
def fitness(solution, idx):
    """Score one GA individual by letting it play a headless Snake game.

    ``solution`` is part of the callback signature (presumably pygad's
    fitness callback) but is not used; the individual is looked up via
    ``idx``. Returns 0 when ``idx`` is None.

    The score is ``n_moves**2 * 2**length`` for snakes shorter than 10,
    and ``n_moves**2 * 1024 * (length - 9)`` otherwise. Both formulas give
    the same value at length 10. A new best score is recorded on the
    global ``ginst`` and printed.
    """
    if idx is None:
        return 0

    policy = GenAction(idx)
    game = SnakeEnv(need_render=False,
                    use_simple=True,
                    set_life=int(GRID_WIDTH_NUM / 3))
    observation = game.reset()
    done = False
    while not done:
        _, observation, done, _ = game(policy(observation))

    length = len(game.status.snake_body)
    moves_sq = game.n_moves**2
    if length >= 10:
        # Past 9 segments, growth switches from exponential to linear.
        score = moves_sq * 1024 * (length - 9)
    else:
        score = moves_sq * (2**length)

    if score > ginst.max_score:
        ginst.max_score = score
        ginst.max_idx = idx
        ginst.params = game.n_moves, length
        print('find new best: ', score, game.n_moves, length)

    return score
示例#2
0
    def __init__(self,
                 n_solutions=2000,
                 n_input=24,
                 n_output=4,
                 hiddens=None,
                 n_gen=0):
        """Build the pygad GANN population and prepare the shared display env.

        Args:
            n_solutions: number of networks in the GA population.
            n_input: number of input neurons.
            n_output: number of output neurons.
            hiddens: neuron count of each hidden layer; defaults to
                ``[18, 18]``. A fresh default list is created per call so
                instances never share one mutable default.
            n_gen: starting generation number; stored on ``ginst`` and
                ``self`` and forwarded to ``init_generation``.
        """
        if hiddens is None:
            hiddens = [18, 18]
        self.n_solutions = n_solutions
        self.n_input = n_input
        self.n_output = n_output
        self.hiddens = hiddens

        self.gann = pygad.gann.GANN(num_solutions=n_solutions,
                                    num_neurons_input=n_input,
                                    num_neurons_output=n_output,
                                    num_neurons_hidden_layers=hiddens,
                                    # One activation per hidden layer; a fixed
                                    # two-element list broke any other depth.
                                    hidden_activations=["relu"] * len(hiddens),
                                    output_activation="softmax")
        global ginst
        ginst.gann = self.gann
        ginst.gene = 0
        ginst.n_gen = n_gen
        # Create the rendered display env only once and share it via ginst.
        if not hasattr(ginst, 'env') or ginst.env is None:
            ginst.env = SnakeEnv(need_render=True, use_simple=True)
            ginst.env.reset()
            ginst.env.render()
            input()  # pause until the user presses Enter
        ginst.env.reset()
        ginst.env.render()

        self.init_generation(n_gen=n_gen)
        self.n_gen = n_gen
示例#3
0
def show_gen(idx):
    """Replay individual ``idx`` in a rendered game, drawing every third frame.

    After each step it prints the ``done`` flag, the snake's direction and
    its remaining life. When ``ginst.n_gen`` is truthy, the generation
    label is passed to ``render`` as extra text. Sleeps 1.5 s once the
    game is over.
    """
    policy = GenAction(idx)
    game = SnakeEnv(use_simple=True)
    observation = game.reset()

    frame = 0
    finished = False
    while not finished:
        frame += 1
        _, observation, finished, _ = game(policy(observation))
        overlay = '代数: {}'.format(ginst.n_gen)
        if frame % 3 == 0:
            if ginst.n_gen:
                game.render(overlay)
            else:
                game.render()
        print(finished, game.status.direction, game.life)

    sleep(1.5)
示例#4
0
文件: play.py 项目: kingyiusuen/snake
def run(display, retrain, num_episodes):
    """Train a Q-learning agent on SnakeEnv for ``num_episodes`` episodes.

    Args:
        display: render the board on every step when truthy.
        retrain: when falsy, first try to load a saved policy.
        num_episodes: number of training episodes.

    Epsilon decays multiplicatively after each episode, floored at
    ``agent.min_epsilon``. Closing the pygame window saves the policy and
    exits the process. The policy is also saved after the last episode.
    """
    pygame.init()
    env = SnakeEnv()
    agent = QlearningAgent(env)
    if not retrain:
        try:
            load_policy(agent)
        except Exception as exc:
            # Loading is best-effort (e.g. no saved policy yet), but report
            # it instead of silently training from scratch. The former bare
            # ``except:`` also swallowed KeyboardInterrupt and SystemExit.
            print('Could not load saved policy, training from scratch: {!r}'.format(exc))

    for _ in tqdm(range(num_episodes)):
        state = env.reset()
        done = False
        while not done:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    save_policy(agent)
                    pygame.quit()
                    sys.exit()
            if display:
                env.render()
            action = agent.act(state)
            next_state, reward, done = env.step(action)
            agent.update_q_value(state, reward, action, next_state, done)
            state = next_state
        agent.epsilon = max(agent.epsilon * agent.epsilon_decay_rate,
                            agent.min_epsilon)
    save_policy(agent)
示例#5
0
def show_gen(idx):
    """Replay individual ``idx`` in a game rendered every frame, then close the env."""
    policy = GenAction(idx)
    game = SnakeEnv(use_simple=True)
    observation = game.reset()

    finished = False
    while not finished:
        _, observation, finished, _ = game(policy(observation))
        game.render()

    game.close()
示例#6
0
def test_easy():
    """Compare a reference policy with the all-0 and all-1 policies.

    Runs 1000 rounds on ``SnakeEnv(0, [3, 6])``. Each round evaluates the
    three policies in a fixed order. Finally it prints the mean
    ``eval_game`` result of each policy.
    """
    rounds = 1000
    env = SnakeEnv(0, [3, 6])

    candidates = [
        ('opt', [1] * 97 + [0] * 3),
        ('0', [0] * 100),
        ('1', [1] * 100),
    ]
    totals = [0] * len(candidates)

    for _ in range(rounds):
        for slot, (_, policy) in enumerate(candidates):
            totals[slot] += eval_game(env, policy)

    for (label, _), total in zip(candidates, totals):
        print('%s %f' % (label, total / rounds))
示例#7
0
def run():
    """Restore the trained DQN policy and record one episode as a video.

    It rebuilds the same Q-network and agent layout used in training, so
    the policy checkpoint under ``<train_dir>/policy`` can be restored. It
    then writes ``<root_dir>/snake<step>.mp4`` via ``capture_run``.

    NOTE(review): ``train_dir`` and ``root_dir`` are not defined in this
    function; they are presumably module-level names — confirm.
    """
    env = tf_py_environment.TFPyEnvironment(SnakeEnv(step_limit=1000))

    # Must be identical to the training network, or the checkpoint will not map.
    q_net = q_network.QNetwork(
        env.observation_spec(),
        env.action_spec(),
        conv_layer_params=(),
        fc_layer_params=(256, 100),
    )

    adam = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3)
    step_counter = tf.compat.v1.train.get_or_create_global_step()

    agent = dqn_agent.DqnAgent(
        env.time_step_spec(),
        env.action_spec(),
        q_network=q_net,
        optimizer=adam,
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=step_counter,
        gamma=0.95,
        epsilon_greedy=0.1,
        n_step_update=1,
    )
    agent.initialize()

    checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=agent.policy,
        global_step=step_counter,
    )
    checkpointer.initialize_or_restore()

    restored_step = step_counter.numpy()
    video_path = os.path.join(root_dir, "snake" + str(restored_step) + ".mp4")
    capture_run(video_path, env, agent.policy)
def run_snake():
    """Train the DeepQNetwork on the snake game for 500000 steps, then plot curves."""
    brain = DeepQNetwork(4, "")
    game = SnakeEnv()

    # Start the game by feeding one "move right" decision first.
    first_frame, _, _, _ = game.step(np.array([0, 0, 0, 1]))
    brain.set_init_state(pre_process(first_frame)[:, :, 0])

    # Main training loop; the counter runs 2..500001 (500000 steps) and is
    # printed every 100 steps.
    for step in range(2, 500002):
        action = brain.choose_action()
        frame, reward, terminal, _ = game.step(action)
        brain.learn(pre_process(frame), action, reward, terminal)
        if step % 100 == 0:
            print(step)

    # Plot the loss curve and the per-round step curve.
    brain.plot_cost()
    game.plot_cost()
class Application:
    """Run a snake agent through training, evaluation and on-screen games.

    ``args`` must provide the attributes read below: snake/food start
    positions, ``Ne``, ``C``, ``gamma``, ``train_eps``, ``test_eps``,
    ``show_eps``, ``window``, ``model_name`` and ``human``.
    """

    def __init__(self, args):
        self.args = args
        self.env = SnakeEnv(args.snake_head_x, args.snake_head_y, args.food_x, args.food_y)
        self.agent = Agent(self.env.get_actions(), args.Ne, args.C, args.gamma)

    def execute(self):
        """Train (if ``train_eps`` != 0) and evaluate the agent, then show games.

        In human mode the train/eval phases are skipped.
        """
        if not self.args.human:
            if self.args.train_eps != 0:
                self.train()
            self.eval()
        self.show_games()

    def train(self):
        """Play ``train_eps`` training games, printing stats every ``window`` games.

        Saves ``utils.CHECKPOINT`` the first time a step returns exactly 1
        point, and saves the final model to ``args.model_name``.
        """
        print("Train Phase:")
        self.agent.train()
        window = self.args.window
        self.points_results = []
        first_eat = True
        start = time.time()

        for game in range(1, self.args.train_eps + 1):
            state = self.env.get_state()
            dead = False
            action = self.agent.choose_action(state, 0, dead)
            while not dead:
                state, points, dead = self.env.step(action)

                # For debug convenience, you can check if your Q-table matches ours for given setting of parameters
                # (see Debug Convenience part on homework 4 web page)
                if first_eat and points == 1:
                    self.agent.save_model(utils.CHECKPOINT)
                    first_eat = False

                action = self.agent.choose_action(state, points, dead)

            points = self.env.get_points()
            self.points_results.append(points)
            if game % window == 0:
                recent = self.points_results[-window:]
                print(
                    "Games:", len(self.points_results) - window, "-", len(self.points_results),
                    "Points (Average:", sum(recent)/window,
                    "Max:", max(recent),
                    "Min:", min(recent),")",
                )
            self.env.reset()
        print("Training takes", time.time() - start, "seconds")
        self.agent.save_model(self.args.model_name)

    def eval(self):
        """Load ``args.model_name`` and play ``test_eps`` games, printing point stats."""
        print("Evaling Phase:")
        self.agent.eval()
        self.agent.load_model(self.args.model_name)
        points_results = []
        start = time.time()

        for game in range(1, self.args.test_eps + 1):
            state = self.env.get_state()
            dead = False
            action = self.agent.choose_action(state, 0, dead)
            while not dead:
                state, points, dead = self.env.step(action)
                action = self.agent.choose_action(state, points, dead)
            points = self.env.get_points()
            points_results.append(points)
            self.env.reset()

        print("Testing takes", time.time() - start, "seconds")
        print("Number of Games:", len(points_results))
        if not points_results:
            # test_eps == 0: nothing to summarise (previously crashed with
            # ZeroDivisionError / max() of an empty list).
            return
        print("Average Points:", sum(points_results)/len(points_results))
        print("Max Points:", max(points_results))
        print("Min Points:", min(points_results))

    def show_games(self):
        """Play ``show_eps`` games on screen and print each game's points.

        Escape or closing the window stops early. In human mode the arrow
        keys steer the snake; otherwise the agent chooses the actions.
        """
        print("Display Games")
        self.env.display()
        pygame.event.pump()
        self.agent.eval()
        points_results = []
        end = False
        for game in range(1, self.args.show_eps + 1):
            state = self.env.get_state()
            dead = False
            action = self.agent.choose_action(state, 0, dead)
            while not dead:
                pygame.event.pump()
                keys = pygame.key.get_pressed()
                # Drain the queue once and share it between the quit check and
                # human input; a separate check_quit() drain would discard the
                # player's KEYDOWN events before they could be read.
                events = pygame.event.get()
                if keys[K_ESCAPE] or self.check_quit(events):
                    end = True
                    break
                state, points, dead = self.env.step(action)
                if not self.args.human:
                    # Q-learning agent
                    action = self.agent.choose_action(state, points, dead)
                else:
                    # Human player: also consider events that arrived during the step.
                    events += pygame.event.get()
                    if self.check_quit(events):
                        end = True
                        break
                    action = self._human_action(events, action)
            if end:
                break
            self.env.reset()
            points_results.append(points)
            print("Game:", str(game)+"/"+str(self.args.show_eps), "Points:", points)
        if len(points_results) == 0:
            return
        print("Average Points:", sum(points_results)/len(points_results))

    def check_quit(self, events=None):
        """Return True if a pygame QUIT event is present.

        Args:
            events: events to inspect. When None, the pygame event queue is
                drained and inspected (discarding all other queued events).
        """
        if events is None:
            events = pygame.event.get()
        return any(event.type == pygame.QUIT for event in events)

    @staticmethod
    def _human_action(events, action):
        """Return the action of the last arrow-key press in ``events``, else ``action``."""
        key_to_action = {
            pygame.K_UP: 2,
            pygame.K_DOWN: 3,
            pygame.K_LEFT: 1,
            pygame.K_RIGHT: 0,
        }
        for event in events:
            if event.type == pygame.KEYDOWN and event.key in key_to_action:
                action = key_to_action[event.key]
        return action
示例#10
0
def run():
    """Train a DQN agent on SnakeEnv with TF-Agents.

    Flow:
    - Restore the agent/metrics and replay-buffer checkpoints from
      /tf-logs/snake/train, if present.
    - Seed the replay buffer with random-policy experience when it holds
      fewer than ``initial_collect_steps`` frames.
    - Loop ``num_iterations`` times, alternating collection and training.
    - Within the loop, periodically log, checkpoint, record a video via
      ``capture_run`` and evaluate on ``eval_env``.

    Hyperparameters (learning_rate, replay_buffer_max_length,
    collect_steps_per_iteration, initial_collect_steps, num_eval_episodes,
    batch_size, num_iterations, train_steps_per_iteration, log_interval,
    train/policy/rb checkpoint intervals, capture_interval, eval_interval)
    and ``capture_run`` are module-level names not defined here.

    NOTE(review): ``avg_returns`` is accumulated but never returned, and
    ``results`` from eager_compute is never used.
    """
    tf_env = tf_py_environment.TFPyEnvironment(SnakeEnv())
    # Separate env for evaluation and video capture; step_limit=50 presumably
    # caps each evaluation episode length (confirm in SnakeEnv).
    eval_env = tf_py_environment.TFPyEnvironment(SnakeEnv(step_limit=50))

    q_net = q_network.QNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        conv_layer_params=(),
        fc_layer_params=(512, 256, 128),
    )

    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    global_counter = tf.compat.v1.train.get_or_create_global_step()

    agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=global_counter,
        gamma=0.95,
        epsilon_greedy=0.1,
        n_step_update=1,
    )

    # Logs and checkpoints live under /tf-logs/snake/{train,eval}.
    root_dir = os.path.join('/tf-logs', 'snake')
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    agent.initialize()

    # Metrics updated by the collect driver (and checkpointed with the agent).
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_max_length,
    )

    # Each run() collects collect_steps_per_iteration steps with the
    # exploration policy into the replay buffer.
    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        agent.collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration,
    )

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=agent,
        global_step=global_counter,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
    )

    # Only saved in this function, never restored here.
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=agent.policy,
        global_step=global_counter,
    )

    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer,
    )

    # Resume from the latest checkpoints when they exist.
    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    # Compile the hot paths with common.function (a tf.function wrapper).
    collect_driver.run = common.function(collect_driver.run)
    agent.train = common.function(agent.train)

    random_policy = random_tf_policy.RandomTFPolicy(tf_env.time_step_spec(),
                                                    tf_env.action_spec())

    # Skip random seeding if a restored replay buffer is already large enough.
    if replay_buffer.num_frames() >= initial_collect_steps:
        logging.info("We loaded memories, not doing random seed")
    else:
        logging.info("Capturing %d steps to seed with random memories",
                     initial_collect_steps)

        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            random_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

    train_summary_writer = tf.summary.create_file_writer(train_dir)
    train_summary_writer.set_as_default()

    avg_returns = []
    avg_return_metric = tf_metrics.AverageReturnMetric(
        buffer_size=num_eval_episodes)
    eval_metrics = [
        avg_return_metric,
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]
    logging.info("Running initial evaluation")
    results = metric_utils.eager_compute(
        eval_metrics,
        eval_env,
        agent.policy,
        num_episodes=num_eval_episodes,
        train_step=global_counter,
        summary_writer=tf.summary.create_file_writer(eval_dir),
        summary_prefix='Metrics',
    )
    avg_returns.append(
        (global_counter.numpy(), avg_return_metric.result().numpy()))
    metric_utils.log_metrics(eval_metrics)

    # Driver state carried across iterations (None = start a fresh episode).
    time_step = None
    policy_state = agent.collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_counter.numpy()
    time_acc = 0

    # Sample batches of 2 consecutive steps (n_step_update + 1 = 2).
    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       sample_batch_size=batch_size,
                                       num_steps=2).prefetch(3)

    iterator = iter(dataset)

    @common.function
    def train_step():
        experience, _ = next(iterator)
        return agent.train(experience)

    for _ in range(num_iterations):
        start_time = time.time()
        time_step, policy_state = collect_driver.run(
            time_step=time_step,
            policy_state=policy_state,
        )

        # NOTE(review): if train_steps_per_iteration is 0, train_loss is
        # never bound and the logging line below raises NameError.
        for _ in range(train_steps_per_iteration):
            train_loss = train_step()
        time_acc += time.time() - start_time

        # NOTE(review): the interval checks below use step % X == 0. If the
        # counter advances by more than 1 per iteration (presumably once per
        # train call), some intervals may never fire — confirm.
        step = global_counter.numpy()

        if step % log_interval == 0:
            logging.info("step = %d, loss = %f", step, train_loss.loss)
            steps_per_sec = (step - timed_at_step) / time_acc
            logging.info("%.3f steps/sec", steps_per_sec)
            timed_at_step = step
            time_acc = 0

        for train_metric in train_metrics:
            train_metric.tf_summaries(train_step=global_counter,
                                      step_metrics=train_metrics[:2])

        if step % train_checkpoint_interval == 0:
            train_checkpointer.save(global_step=step)

        if step % policy_checkpoint_interval == 0:
            policy_checkpointer.save(global_step=step)

        if step % rb_checkpoint_interval == 0:
            rb_checkpointer.save(global_step=step)

        if step % capture_interval == 0:
            print("Capturing run:")
            capture_run(os.path.join(root_dir, "snake" + str(step) + ".mp4"),
                        eval_env, agent.policy)

        # NOTE(review): a new eval summary writer is created on every
        # evaluation; creating it once may be preferable.
        if step % eval_interval == 0:
            print("EVALUTION TIME:")
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_env,
                agent.policy,
                num_episodes=num_eval_episodes,
                train_step=global_counter,
                summary_writer=tf.summary.create_file_writer(eval_dir),
                summary_prefix='Metrics',
            )
            metric_utils.log_metrics(eval_metrics)
            avg_returns.append(
                (global_counter.numpy(), avg_return_metric.result().numpy()))
示例#11
0
    state = env.reset()
    return_val = 0
    while True:
        act = agent.play(state)
        state, reward, terminate, _ = env.step(act)
        return_val += reward
        if terminate:
            break
    return return_val


def test_agent(env, agent):
    """Train ``agent``, print its state->action table, then print the mean of
    ``eval_policy`` over 100 episodes."""
    agent.learn()
    print ('states->actions pi = {}'.format(agent.get_pi()))
    episodes = 100
    total = sum(eval_policy(env, agent) for _ in range(episodes))
    print ("avg reward = {}".format(total / float(episodes)))


if __name__ == '__main__':
    snake = SnakeEnv(0, [3, 6])
    # Other agents that were tried (currently disabled):
    #   TableAgent(snake, PolicyIter(-1)), TableAgent(snake, ValueIter(-1)),
    #   TableAgent(snake, GeneralizedIter(1, 10)),
    #   ModelFreeAgent(snake, MonteCarlo(0.1)), ModelFreeAgent(snake, Sarsa(0.1)),
    #   ModelFreeAgent(snake, QLearning(0.1)),
    #   DQNAgent(snake)  # did not converge
    test_agent(snake, OpenAiAgent(snake))
示例#12
0
def set_up_env(io):
    """Wrap a SnakeEnv built from ``io`` for TF-Agents; return ``(train_env, eval_env)``.

    NOTE(review): both TF wrappers share the same underlying SnakeEnv, so
    stepping the eval env also advances the training env's game. This may
    be intentional if ``io`` is a single game connection — confirm.
    """
    py_env = SnakeEnv(io)
    wrap = tf_py_environment.TFPyEnvironment
    return wrap(py_env), wrap(py_env)
 def __init__(self, args):
     """Keep ``args`` and build the snake env and agent from its settings."""
     self.args = args
     start = (args.snake_head_x, args.snake_head_y, args.food_x, args.food_y)
     self.env = SnakeEnv(*start)
     self.agent = Agent(self.env.get_actions(), args.Ne, args.C, args.gamma)
示例#14
0
        tf.keras.layers.Dense(ACTION_DIM, activation="linear")
    ])
    model.compile(loss='mean_squared_error',
                  optimizer=tf.keras.optimizers.Adam(0.001))

    if weights:
        model.load_weights(weights)
    return model


def act(state, epsilon=0.1, step=0):
    """Predict the action for ``state``: the index of the largest model output.

    ``epsilon`` and ``step`` are accepted but unused (no exploration here).
    Uses the module-level ``model``.
    """
    q_values = model.predict(np.array([state]))[0]
    return np.argmax(q_values)


# Demo: load trained weights and let the model play 1000 rendered games.
env = SnakeEnv()
model = create_model("weights.hdf5")
for i in range(1000):
    state = env.reset()
    step = 0
    while True:
        env.render()
        # Actually play the model's action. Previously `done` was forced to
        # True and the step was commented out, so no game was ever played
        # and the prediction was discarded.
        action = act(state, step=step)
        state, reward, done, _ = env.step(action)
        step += 1
        if done:
            print('Game', i + 1, '      Score:', env.score)
            break
env.close()
示例#15
0
        s_batch = np.array([replay[0] for replay in replay_batch])
        next_s_batch = np.array([replay[2] for replay in replay_batch])

        Q = self.model.predict(s_batch)
        Q_next = self.target_model.predict(next_s_batch)

        # 使用公式更新训练集中的Q值
        for i, replay in enumerate(replay_batch):
            _, a, _, reward = replay
            Q[i][a] = (1 - lr) * Q[i][a] + lr * (reward + factor * np.amax(Q_next[i]))

        # 传入网络进行训练
        self.model.fit(s_batch, Q, verbose=0)


env = SnakeEnv()
episodes = 1000  # number of training games

agent = DQN()
for i in range(episodes):
    state = env.reset()
    done = False
    # One experience is stored and one training step taken per move.
    while not done:
        action = agent.act(state)
        next_state, reward, done, _ = env.step(action)
        agent.remember(state, action, next_state, reward)
        agent.train()
        state = next_state
    print('Game', i + 1, '      Score:', env.score)
示例#16
0
import sys
from snake_env import SnakeEnv
import numpy as np
from DQNetwork import DQNetwork
from Results import Results, test
from Options import Options
import pickle
import time

opt = Options().parse()

nrow, ncol = opt.gridsize, opt.gridsize
n_channels = opt.n_ch

# Observation colour mode follows the channel count: 1 -> gray, 3 -> rgb.
if n_channels == 1:
    env = SnakeEnv(nrow, ncol, colors='gray')
elif n_channels == 3:
    env = SnakeEnv(nrow, ncol, colors='rgb')
else:
    # Previously fell through silently, leaving `env` undefined and failing
    # later with a confusing NameError.
    raise ValueError('n_ch must be 1 or 3, got {!r}'.format(n_channels))

n_train = opt.n_train
n_episodes = opt.n_episodes
n_batch = opt.n_batch
imax = opt.imax
min_epsilon = opt.min_epsilon
N_memory = opt.n_memory

model = DQNetwork(4, (n_channels, nrow, ncol), conv=opt.conv)

res = Results()

loadModel = opt.load
示例#17
0
def play(record=0, no_render=False):
    """Play one game with the BFS shortest-path policy, optionally recording it.

    Args:
        record: when non-zero, record ``(obs, action)`` pairs and return
            ``(x, y)`` once ``record`` pairs are collected. Per-step
            rendering is skipped in this mode.
        no_render: disable rendering entirely.

    Returns:
        ``(x, y)`` lists of observations and actions when recording, also
        when the game ends before ``record`` samples were collected
        (previously that case returned None and lost the data).
        Otherwise None.

    The env is closed on every exit path, including the early return.
    """
    env = SnakeEnv(need_render=not no_render, alg='最短路径')
    try:
        obs = env.reset()
        env.render()
        input()  # wait for Enter before starting
        x, y = [], []

        # Map (current cell - next cell) offsets to the env's action codes.
        directions = {
            (-1, 0): env.right,
            (1, 0): env.left,
            (0, -1): env.down,
            (0, 1): env.up
        }

        need_record = bool(record)
        while True:
            if not record and not no_render:
                env.render()
            # Head cell (value 2) and food cell (value -1).
            # NOTE(review): the head is stored as (axis1, axis0) but the food
            # as (axis0, axis1); confirm this matches what bfs() expects.
            src = np.where(obs == 2)
            src = int(src[1]), int(src[0])
            dst = np.where(obs == -1)
            dst = int(dst[0]), int(dst[1])

            paths = bfs(obs, start=src, dst=dst)
            if paths is None:
                # No path to the food: take a random move.
                action = np.random.randint(0, 4)
            else:
                nxt = paths[1]
                dire = src[0] - nxt[0], src[1] - nxt[1]
                action = directions[dire]

            if need_record:
                x.append(obs)
                y.append(action)
                if len(y) >= record:
                    return x, y

                if len(y) % 1000 == 0:
                    print(len(y))

            _, obs, done, _ = env(action)

            if done:
                print(env.status.score)
                sleep(1.5)
                break
        if not record and not no_render:
            env.render()
    finally:
        env.close()

    if need_record:
        return x, y
示例#18
0
def draw_graph():
    """Visualise construction of a cycle over the grid, then play a game along it.

    Steps (all helpers are defined elsewhere):
    1. Build the full grid graph, prune it with ``deletion``, and draw the
       edges whose flag is set.
    2. Call ``destroy`` repeatedly until it returns a falsy length. Each
       time the length stops changing, the global ``MAX_DEPTH`` is
       incremented.
    3. Call ``connector`` until it returns truthy, then order the result
       with ``get_list_circle`` (presumably a Hamiltonian cycle; the env is
       labelled 'HC + BFS').
    4. Play: by default step to the next cell along the cycle. Take the
       ``dfs_policy`` shortcut only when the ``rel_pos`` values relative to
       the tail satisfy food >= shortcut target >= head >= tail.

    Screen redraws are throttled: every ``remainder`` steps, with
    ``remainder`` chosen from the snake length. The function blocks at the
    end on ``sleep(10)`` and ``input()``.
    """
    print(GRID_HEIGHT_NUM, GRID_WIDTH_NUM)
    # input()
    graph = build_graph(row=GRID_HEIGHT_NUM, col=GRID_WIDTH_NUM)
    # Untouched copy of the full grid graph, passed to destroy/connector.
    total_graph = deepcopy(graph)
    env = SnakeEnv(set_life=100000, alg='HC + BFS', no_sight_disp=True)
    env.reset()
    sleep(1)

    # flags[(sp, ep)] marks the edges that are currently kept; draw them.
    graph, flags = deletion(graph, env)
    for sp in graph:
        for ep in graph[sp]:
            if flags[(sp, ep)]:
                # print(sp, ep)
                env.draw_connection(sp, ep, width=4)
            # env.render()

    import pygame
    pygame.display.update()
    pre_len = None
    # Repeat destroy() until it returns a falsy length; bump the global
    # MAX_DEPTH whenever two consecutive calls return the same length.
    while True:
        sd_len = destroy(graph, total_graph, flags, env=env)
        print('sd: ', sd_len)
        if pre_len is not None and pre_len == sd_len:
            global MAX_DEPTH
            print('+1')
            MAX_DEPTH += 1

        pre_len = sd_len

        show_graph(graph, flags, env)
        if not sd_len:
            break

    sleep(1)

    show_graph(graph, flags, env)
    counter = 0
    # Retry connector() until it reports success, counting attempts.
    while not connector(graph, total_graph, flags, env):
        counter += 1
        print('counter: ', counter)

    sleep(1)

    # Redraw the final kept edges in yellow.
    for sp in graph:
        for ep in graph[sp]:
            if flags[(sp, ep)]:
                env.draw_connection(sp, ep, color=(0xff, 0xff, 0), width=4)

    # NOTE(review): redundant; pygame is already imported above.
    import pygame
    show_graph(graph, flags, env)
    circle = get_list_circle(graph)
    print(circle)
    pos_encoder = {pos: i for i, pos in enumerate(circle)}
    # pos_decoder = {i: pos for i, pos in enumerate(circle)}
    # Cycle index -> (x, y) cell, and (x, y) cell -> cycle index, where a
    # flat cell id is converted as x = pos % width, y = pos // width.
    pos_xy_decoder = {
        i: (pos % GRID_WIDTH_NUM, pos // GRID_WIDTH_NUM)
        for i, pos in enumerate(circle)
    }
    pos_xy_encoder = {(pos % GRID_WIDTH_NUM, pos // GRID_WIDTH_NUM): i
                      for i, pos in enumerate(circle)}
    obs = env.reset()
    c = 0
    while True:
        c += 1

        # Redraw every `remainder` steps; the on-screen label shows
        # remainder * 5 as the speed-up factor.
        if len(env.status.snake_body) < 15:
            remainder = 20
        elif len(env.status.snake_body) < 30:
            remainder = 20
        elif len(env.status.snake_body) < 60:
            remainder = 30
        elif len(env.status.snake_body) < 90:
            remainder = 30
        elif len(env.status.snake_body) < 120:
            remainder = 40
        elif len(env.status.snake_body) < 150:
            remainder = 80
        elif len(env.status.snake_body) < 300:
            remainder = 100
        elif len(env.status.snake_body) < (GRID_WIDTH_NUM * GRID_HEIGHT_NUM -
                                           10):
            remainder = 30
        else:
            remainder = 5
        bfs_action, dst = dfs_policy(obs, env)
        # Large sentinel index used when dfs_policy returned no target cell.
        bfs_dst_idx = 100000000
        if dst:
            bfs_dst_idx = pos_xy_encoder[dst]
        head = env.status.snake_body[0]
        # NOTE: head_pos/tail_pos and head_idx/tail_idx hold the same values.
        head_pos, tail_pos = pos_xy_encoder[head], pos_xy_encoder[
            env.status.snake_body[-1]]
        head_idx, tail_idx = pos_xy_encoder[head], pos_xy_encoder[
            env.status.snake_body[-1]]

        # Default move: the next cell along the cycle.
        hc_next_pos = pos_xy_decoder[(head_pos + 1) % len(graph)]

        # Map (head - next cell) offsets to the env's action codes.
        directions = {
            (-1, 0): env.right,
            (1, 0): env.left,
            (0, -1): env.down,
            (0, 1): env.up
        }
        dire = head[0] - hc_next_pos[0], head[1] - hc_next_pos[1]
        print(head, hc_next_pos, dst, dst not in env.status.snake_body[:-1])
        print(head_idx, tail_idx, bfs_dst_idx)
        action = directions[dire]
        # No food left: show the final board and stop.
        if not env.status.food_pos:
            show_graph(graph,
                       flags,
                       env,
                       update=True,
                       width=1,
                       extra='倍速: {} X'.format(remainder * 5))
            break

        food_idx = pos_xy_encoder[env.status.food_pos]
        if bfs_action:
            print(food_idx, bfs_dst_idx, head_idx, tail_idx)
            print(rel_pos(food_idx, tail_idx, len(graph)),
                  rel_pos(bfs_dst_idx, tail_idx, len(graph)),
                  rel_pos(head_idx, tail_idx, len(graph)),
                  rel_pos(tail_idx, tail_idx, len(graph)))
            # Take the shortcut only when, measured from the tail, the order
            # is tail <= head <= shortcut target <= food.
            if rel_pos(food_idx, tail_idx, len(graph)) >= rel_pos(
                    bfs_dst_idx, tail_idx, len(graph)) >= rel_pos(
                        head_idx, tail_idx, len(graph)) >= rel_pos(
                            tail_idx, tail_idx, len(graph)):
                action = bfs_action
                pass

        reward, obs, done, _ = env(action)
        if done:
            show_graph(graph,
                       flags,
                       env,
                       update=True,
                       width=1,
                       extra='倍速: {} X'.format(remainder * 5))
            print(done)
            break
        # env.screen.blit(env.background, (0, 0))

        if c % remainder == 0:
            show_graph(graph,
                       flags,
                       env,
                       update=True,
                       width=1,
                       extra='倍速: {} X'.format(remainder * 5))
        # env.render(blit=False)

    show_graph(graph, flags, env, update=True, width=1)
    sleep(10)
    input()
示例#19
0
import argparse
import gym
from stable_baselines.common.policies import CnnPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2

from snake_env import SnakeEnv

parser = argparse.ArgumentParser()
parser.add_argument('--mode', choices=['train', 'test'], default='test')
args = parser.parse_args()

# stable-baselines expects a vectorized env; wrap a single SnakeEnv.
# The raw env is bound as a default argument so the factory does not depend
# on the later rebinding of the name `env`.
env = SnakeEnv((20, 20), 'standard')
env = DummyVecEnv([lambda raw=env: raw])
model = PPO2(CnnPolicy, env, verbose=1)

if args.mode == 'train':

    model.learn(total_timesteps=20000)
    model.save('policy_baseline_snake')

elif args.mode == 'test':

    obs = env.reset()
    model.load('policy_baseline_snake')

    for i in range(1000):
        action, _states = model.predict(obs)
        obs, reward, done, info = env.step(action)
        env.render()
        if done:
示例#20
0
# -*- coding: utf-8 -*-

import time
from snake_env import SnakeEnv

from simple_mlp import Policy


def play(env, policy):
    """Let ``policy`` play forever, rendering every step.

    After a game over, the env is reset and play resumes following a
    1-second pause.
    """
    observation = env.reset()
    while True:
        _, observation, done, _ = env(policy.predict(observation))
        env.render()
        if not done:
            continue
        observation = env.reset()
        time.sleep(1)


if __name__ == '__main__':
    mlp_policy = Policy(pre_trained='pretrained/mlp-v0.joblib')
    game_env = SnakeEnv(alg='MLP')
    # Show the initial board and wait for Enter before handing over control.
    game_env.reset()
    game_env.render()
    input()
    play(game_env, mlp_policy)