import sys

import pygame
from tqdm import tqdm

# SnakeEnv, QlearningAgent, load_policy and save_policy come from the project itself.


def run(display, retrain, num_episodes):
    pygame.init()
    env = SnakeEnv()
    agent = QlearningAgent(env)
    # Resume from a previously saved policy unless a fresh retrain was requested.
    if not retrain:
        try:
            load_policy(agent)
        except Exception:
            pass
    for _ in tqdm(range(num_episodes)):
        state = env.reset()
        done = False
        while not done:
            # Let the window be closed cleanly mid-training.
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    save_policy(agent)
                    pygame.quit()
                    sys.exit()
            if display:
                env.render()
            action = agent.act(state)
            next_state, reward, done = env.step(action)
            agent.update_q_value(state, reward, action, next_state, done)
            state = next_state
        # Decay exploration once per episode, never below the floor.
        agent.epsilon = max(agent.epsilon * agent.epsilon_decay_rate, agent.min_epsilon)
    save_policy(agent)
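# QlearningAgent.update_q_value itself is not shown above. A minimal sketch of the
# standard tabular update it is expected to perform, assuming a per-state array
# Q-table and hypothetical learning_rate / discount attributes (illustrative names,
# not necessarily the project's):
def update_q_value(self, state, reward, action, next_state, done):
    # Terminal transitions contribute no future value.
    best_next = 0.0 if done else max(self.q_table[next_state])
    target = reward + self.discount * best_next
    # Move Q(s, a) toward the bootstrapped target.
    self.q_table[state][action] += self.learning_rate * (target - self.q_table[state][action])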
import numpy as np


def run_snake():
    brain = DeepQNetwork(4, "")
    snakeGame = SnakeEnv()
    # Bootstrap the game with an initial "move right" action.
    observation, reward, terminal, score = snakeGame.step(np.array([0, 0, 0, 1]))
    observation = pre_process(observation)
    brain.set_init_state(observation[:, :, 0])
    # Main training loop.
    i = 1  # step counter
    while i <= 500000:
        i = i + 1
        action = brain.choose_action()
        next_observation, reward, terminal, score = snakeGame.step(action)
        # print(reward)
        next_observation = pre_process(next_observation)
        brain.learn(next_observation, action, reward, terminal)
        if (i % 100) == 0:
            print(i)
    # Plot the loss curve and the steps-per-round curve.
    brain.plot_cost()
    snakeGame.plot_cost()
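# pre_process is not defined in this snippet. Pipelines in this style usually shrink the
# raw frame to a small binary grayscale image before stacking it into the network state;
# a plausible sketch (the 80x80 size and the threshold value are assumptions):
import cv2
import numpy as np

def pre_process(observation):
    # Resize the raw frame and collapse it to a single grayscale channel.
    observation = cv2.cvtColor(cv2.resize(observation, (80, 80)), cv2.COLOR_BGR2GRAY)
    # Binarize so the snake and food stand out from the background.
    _, observation = cv2.threshold(observation, 1, 255, cv2.THRESH_BINARY)
    return np.reshape(observation, (80, 80, 1))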
import argparse

from stable_baselines.common.policies import CnnPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2

from snake_env import SnakeEnv

parser = argparse.ArgumentParser()
parser.add_argument('--mode', choices=['train', 'test'], default='test')
args = parser.parse_args()

env = SnakeEnv((20, 20), 'standard')
env = DummyVecEnv([lambda: env])
model = PPO2(CnnPolicy, env, verbose=1)

if args.mode == 'train':
    model.learn(total_timesteps=20000)
    model.save('policy_baseline_snake')
elif args.mode == 'test':
    obs = env.reset()
    # PPO2.load is a classmethod that returns a new model, so reassign it.
    model = PPO2.load('policy_baseline_snake', env=env)
    for i in range(1000):
        action, _states = model.predict(obs)
        obs, reward, done, info = env.step(action)
        env.render()
        if done:
            env.reset()

env.close()
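# Note: with DummyVecEnv the observations, rewards and dones come back as length-1
# batches (one entry per wrapped env), and the vectorized env resets a sub-env
# automatically when its episode ends, so the manual env.reset() on done is
# effectively redundant but harmless.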
import time

import pygame
from pygame.locals import K_ESCAPE

# SnakeEnv, Agent and utils are provided by the project.


class Application:
    def __init__(self, args):
        self.args = args
        self.env = SnakeEnv(args.snake_head_x, args.snake_head_y, args.food_x, args.food_y)
        self.agent = Agent(self.env.get_actions(), args.Ne, args.C, args.gamma)

    def execute(self):
        if not self.args.human:
            if self.args.train_eps != 0:
                self.train()
            self.eval()
        self.show_games()

    def train(self):
        print("Train Phase:")
        self.agent.train()
        window = self.args.window
        self.points_results = []
        first_eat = True
        start = time.time()
        for game in range(1, self.args.train_eps + 1):
            state = self.env.get_state()
            dead = False
            action = self.agent.choose_action(state, 0, dead)
            while not dead:
                state, points, dead = self.env.step(action)
                # For debugging convenience, you can check whether your Q-table matches
                # ours for the given parameter settings (see the Debug Convenience part
                # of the homework 4 web page).
                if first_eat and points == 1:
                    self.agent.save_model(utils.CHECKPOINT)
                    first_eat = False
                action = self.agent.choose_action(state, points, dead)
            points = self.env.get_points()
            self.points_results.append(points)
            if game % self.args.window == 0:
                print(
                    "Games:", len(self.points_results) - window, "-", len(self.points_results),
                    "Points (Average:", sum(self.points_results[-window:]) / window,
                    "Max:", max(self.points_results[-window:]),
                    "Min:", min(self.points_results[-window:]), ")",
                )
            self.env.reset()
        print("Training takes", time.time() - start, "seconds")
        self.agent.save_model(self.args.model_name)

    def eval(self):
        print("Evaluation Phase:")
        self.agent.eval()
        self.agent.load_model(self.args.model_name)
        points_results = []
        start = time.time()
        for game in range(1, self.args.test_eps + 1):
            state = self.env.get_state()
            dead = False
            action = self.agent.choose_action(state, 0, dead)
            while not dead:
                state, points, dead = self.env.step(action)
                action = self.agent.choose_action(state, points, dead)
            points = self.env.get_points()
            points_results.append(points)
            self.env.reset()
        print("Testing takes", time.time() - start, "seconds")
        print("Number of Games:", len(points_results))
        print("Average Points:", sum(points_results) / len(points_results))
        print("Max Points:", max(points_results))
        print("Min Points:", min(points_results))

    def show_games(self):
        print("Display Games")
        self.env.display()
        pygame.event.pump()
        self.agent.eval()
        points_results = []
        end = False
        for game in range(1, self.args.show_eps + 1):
            state = self.env.get_state()
            dead = False
            action = self.agent.choose_action(state, 0, dead)
            count = 0
            while not dead:
                count += 1
                pygame.event.pump()
                keys = pygame.key.get_pressed()
                if keys[K_ESCAPE] or self.check_quit():
                    end = True
                    break
                state, points, dead = self.env.step(action)
                # Q-learning agent
                if not self.args.human:
                    action = self.agent.choose_action(state, points, dead)
                # Human player: read arrow keys.
                else:
                    for event in pygame.event.get():
                        if event.type == pygame.KEYDOWN:
                            if event.key == pygame.K_UP:
                                action = 2
                            elif event.key == pygame.K_DOWN:
                                action = 3
                            elif event.key == pygame.K_LEFT:
                                action = 1
                            elif event.key == pygame.K_RIGHT:
                                action = 0
            if end:
                break
            self.env.reset()
            points_results.append(points)
            print("Game:", str(game) + "/" + str(self.args.show_eps), "Points:", points)
        if len(points_results) == 0:
            return
        print("Average Points:", sum(points_results) / len(points_results))

    def check_quit(self):
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                return True
        return False
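# Agent.choose_action(state, points, dead) is not shown here. Given the (Ne, C, gamma)
# constructor arguments, a common pattern for such an agent is an exploration function
# that favours under-visited actions, with C controlling a decaying learning rate
# alpha = C / (C + N(s, a)). A rough sketch of the action choice only (names are
# illustrative, and any tie-breaking rules are omitted):
def pick_action(q_row, n_row, Ne):
    # f(u, n): treat actions tried fewer than Ne times as maximally attractive.
    scores = [1.0 if n_row[a] < Ne else q_row[a] for a in range(len(q_row))]
    return max(range(len(scores)), key=lambda a: scores[a])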
    # Inside the DQN agent's training step: rewrite the sampled Q-values with the
    # Q-learning targets, then fit the network on them.
    for i, replay in enumerate(replay_batch):
        _, a, _, reward = replay
        Q[i][a] = (1 - lr) * Q[i][a] + lr * (reward + factor * np.amax(Q_next[i]))
    # Train the network on the updated targets.
    self.model.fit(s_batch, Q, verbose=0)


env = SnakeEnv()
episodes = 1000  # number of training episodes
agent = DQN()
for i in range(episodes):
    state = env.reset()
    while True:
        # env.render(speed=0)
        action = agent.act(state)
        next_state, reward, done, _ = env.step(action)
        agent.remember(state, action, next_state, reward)
        agent.train()
        state = next_state
        if done:
            print('Game', i + 1, ' Score:', env.score)
            break
    if (i + 1) % 10 == 0:
        agent.save_model()
agent.save_model()
env.close()
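# The batch arrays used in the update loop above (replay_batch, s_batch, Q, Q_next) are
# built before that loop runs. A minimal sketch of that preparation step, assuming
# replay tuples are stored as (state, action, next_state, reward) as in the
# agent.remember call above, and a Keras-style self.model:
import random
import numpy as np

def prepare_batch(model, memory, batch_size):
    replay_batch = random.sample(memory, batch_size)
    s_batch = np.array([replay[0] for replay in replay_batch])       # current states
    next_s_batch = np.array([replay[2] for replay in replay_batch])  # successor states
    Q = model.predict(s_batch)            # predictions to be overwritten with targets
    Q_next = model.predict(next_s_batch)  # bootstrap values for the targets
    return replay_batch, s_batch, Q, Q_next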
lengths_i_train = []
scores_i_train = []
for i_episode in range(n_episodes):
    i = 0
    done = False
    grid = env.reset()
    grid = grid.reshape((1, n_channels, env.nrow, env.ncol))
    t0 = time.time()
    while i < imax:
        i += 1
        source = grid.copy()
        # Epsilon-greedy action selection over the four move directions.
        if epsilon >= np.random.rand():
            action = np.random.randint(4)
        else:
            action = np.argmax(model.predict(source))
        grid, reward, done = env.step(action)
        grid = grid.reshape((1, n_channels, env.nrow, env.ncol))
        # Store the transition in replay memory.
        observation = {'source': source, 'action': action,
                       'dest': grid, 'reward': reward, 'final': done}
        res.memory.append(observation)
        if done:
            break
    lengths_i_train.append(i)
    scores_i_train.append(env.score)
    t1 = time.time()
res.lengths.append(lengths_i_train)
res.scores.append(scores_i_train)
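# This excerpt only fills res.memory; the fitting step is not part of it. A sketch of
# how the stored transitions could be turned into one supervised batch for a
# Keras-style model (gamma, batch_size and the sampling scheme are assumptions):
import random
import numpy as np

def replay_train_step(model, memory, batch_size, gamma):
    batch = random.sample(memory, min(batch_size, len(memory)))
    sources = np.concatenate([t['source'] for t in batch])  # (B, n_channels, nrow, ncol)
    dests = np.concatenate([t['dest'] for t in batch])
    targets = model.predict(sources)
    next_q = model.predict(dests)
    for k, t in enumerate(batch):
        # Bootstrap from the successor state unless the transition ended the episode.
        targets[k, t['action']] = t['reward'] if t['final'] else t['reward'] + gamma * np.max(next_q[k])
    model.fit(sources, targets, verbose=0)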