def render_episode(env, estimator: DoubleDQN):
    """Play one greedy episode and render every step.

    The environment is reset, then at each step the action with the largest
    value in ``estimator.predict(state)`` is taken and the frame is drawn
    with ``env.render()``.  The loop stops as soon as ``env.step`` reports
    the episode as done.

    Args:
        env: Environment object providing ``reset()``, ``step(action)``
            (returning a 4-tuple ``(state, reward, done, info)``),
            ``render()`` and ``close()``.
        estimator: Model whose ``predict(state)`` result can be passed to
            ``torch.argmax``.

    Returns:
        None.

    Raises:
        Whatever ``predict``, ``step`` or ``render`` raise; the environment
        is still closed in that case.
    """
    state = env.reset()
    try:
        is_done = False
        while not is_done:
            action = torch.argmax(estimator.predict(state)).item()
            state, _reward, is_done, _info = env.step(action)
            env.render()
    finally:
        # Always release the environment, even when a step or the renderer
        # fails half-way through the episode.
        env.close()
def run_episode(env, estimator: DoubleDQN):
    """Run a single greedy episode and return the accumulated reward.

    Args:
        env: Environment object providing ``reset()`` and ``step(action)``
            (returning a 4-tuple ``(state, reward, done, info)``).
        estimator: Model whose ``predict(state)`` result can be passed to
            ``torch.argmax``.

    Returns:
        Sum of the rewards returned by ``env.step`` over the episode.
    """
    observation = env.reset()
    episode_return = 0
    while True:
        q_values = estimator.predict(observation)
        greedy_action = torch.argmax(q_values).item()
        observation, step_reward, finished, _ = env.step(greedy_action)
        episode_return += step_reward
        if finished:
            return episode_return
def main(argv=None):
    """Parse the command line and start training the selected network.

    Recognised options are ``-h`` / ``--help`` and ``-n`` / ``--network``
    with the value ``dqn`` or ``doubledqn``.

    Args:
        argv: Full argument vector with the program name first.  Defaults
            to ``sys.argv``.

    Returns:
        int: ``0`` after training finished or help was printed, ``1`` when
        no option was given or the network name is unknown, ``2`` when the
        options could not be parsed.
    """
    if argv is None:
        argv = sys.argv

    try:
        opts, _args = getopt.getopt(argv[1:], "hn:", ["help", "network="])
    except getopt.GetoptError as err:
        # Diagnostics go to stderr.  Passing sys.stderr as the first
        # positional argument of print() would write the stream's repr to
        # stdout instead of using it as the destination.
        print(err.msg, file=sys.stderr)
        print("for help use --help", file=sys.stderr)
        return 2

    if not opts:
        print("Please specify parameters!")
        return 1

    for opt, arg in opts:
        if opt in ("-h", "--help"):
            print(__doc__)
            return 0
        elif opt in ("-n", "--network"):
            # The network modules are imported lazily so that only the
            # selected implementation is loaded.
            if arg == 'dqn':
                from dqn import DeepQNetwork
                score_graph_path = './saved_model/'
                network = DeepQNetwork(
                    e_greedy=0.1,
                    output_graph=True,
                    save_path=score_graph_path,
                )
            elif arg == 'doubledqn':
                from double_dqn import DoubleDQN
                score_graph_path = './saved_model_doubledqn/'
                network = DoubleDQN(
                    e_greedy=0.1,
                    output_graph=True,
                    save_path=score_graph_path,
                )
            else:
                print(
                    "You could choose 'dqn', 'doubledqn' as network's parameter"
                )
                return 1

    # Every option accepted by getopt either returned above or bound
    # ``network`` and ``score_graph_path``.
    train(network, score_graph_path)
    return 0
# q_learning: training loop that runs `n_episode` episodes against `env`.
#
# Per episode (as far as the statements below show):
#   * builds an epsilon-greedy policy with
#     gen_epsilon_greedy_policy(estimator, epsilon, n_action);
#   * steps the environment, computes modify_reward(next_state[0]) and adds
#     the raw / modified rewards to total_reward_episode[episode] and
#     total_mod_reward_episode[episode];
#   * appends each transition (state, action, next_state, modified_reward,
#     is_done) to both `memory` and the per-episode list;
#   * when the episode ends with a total raw reward above -200, the
#     per-episode list is appended to `total_memory`;
#   * otherwise calls estimator.replay(memory, replay_size, gamma) after the
#     step and moves on to next_state;
#   * prints a progress line and decays epsilon multiplicatively by
#     `epsilon_decay`, never below 0.01.
# When episode % target_update_every == 1 the target network is refreshed
# with estimator.copy_target() and run_episode() is called 100 times to count
# runs scoring above -200.  When episode % 100 == 1 the stored good episodes
# are replayed in full and render_episode() is called.
#
# n_action, modify_reward, gen_epsilon_greedy_policy, memory, total_memory,
# total_reward_episode and total_mod_reward_episode are not defined in this
# function; presumably module-level names - TODO confirm.
# NOTE(review): with target_update_every == 1 the test `episode % 1 == 1` is
# never true, so copy_target() would never run.
# NOTE(review): the original indentation of this function is not preserved
# here, so whether the 100-run check / `break` sit inside the target-update
# branch, and whether render_episode() is inside the replay loop, must be
# confirmed against the original file.
def q_learning(env, estimator: DoubleDQN, n_episode, target_update_every = 10, gamma = 1.0, epsilon = 0.1, epsilon_decay = 0.99, replay_size = 20): for episode in range(n_episode): policy = gen_epsilon_greedy_policy(estimator, epsilon, n_action) state = env.reset() is_done = False episode_memory = [] while not is_done: action = policy(state) next_state, reward, is_done, _ = env.step(action) modified_reward = modify_reward(next_state[0]) total_reward_episode[episode] += reward total_mod_reward_episode[episode] += modified_reward memory.append((state, action, next_state, modified_reward, is_done)) episode_memory.append((state, action, next_state, modified_reward, is_done)) if is_done: if total_reward_episode[episode] > -200: # nice episode to replay :) total_memory.append(episode_memory) break estimator.replay(memory, replay_size, gamma) state = next_state print('Episode {}: reward: {}, epsilon: {}'.format( episode, total_reward_episode[episode], epsilon )) epsilon = max(epsilon * epsilon_decay, 0.01) if (episode % target_update_every) == 1: # update targets NN estimator.copy_target() # check if NN is well trained total_wins = 0 for test in range(100): total_wins += 1 if run_episode(env, estimator) > -200 else 0 if (total_wins > 1): break if (episode % 100) == 1: for good_ep_mem in total_memory: estimator.replay(good_ep_mem, len(good_ep_mem), gamma) render_episode(env, estimator)
# play: main loop of the game.  Every pass of `while True` drains
# pygame.event.get(), redraws the current screen and finishes with
# pygame.display.update() and self.clock.tick(30); self.delay is advanced
# modulo 30 on each pass.
#
# Event handling visible in this block:
#   * QUIT            - pygame.quit() followed by sys.exit().
#   * KEYDOWN         - SPACE / UP set self.pressed and call self.bird.fly();
#                       P / ESC toggle self.paused; G (only while self.start
#                       is set and no `ai_model` attribute exists yet) calls
#                       self.init_vars(ai=True) and creates a DoubleDQN().
#   * MOUSEMOTION     - while self.setting and self.mouse_down are set, the
#                       pointer x position drives the six RGB sliders and the
#                       two volume sliders; leaving every slider rectangle
#                       clears self.mouse_down.
#   * MOUSEBUTTONUP   - clears self.mouse_down on the settings screen.
#   * MOUSEBUTTONDOWN - left button: navigation between the start, ranking,
#                       settings and share screens, slider / arrow clicks on
#                       the settings screen, pause toggle, flap, and the
#                       retry / share / menu buttons after the bird died.
# Drawing: background and ground first, then the start, settings, ranking,
# share or game screen depending on self.start / self.setting /
# self.ranking / self.share.  On the game screen with self.ai set, the
# frame is captured with pygame.surfarray.array3d() and passed to
# self.ai_model; the bird flaps when the returned action_index is 1.
# The score grows by one when the bird's centerx is inside a 4-pixel window
# starting at a pipe's centerx.  After a crash the score is stored once
# through score.Sql.set_score() (guarded by self.recorded).
#
# NOTE(review): the original indentation of this method is not preserved
# here (statements are fused onto a few long lines), so the exact nesting of
# the branches below must be confirmed against the original file.
# NOTE(review): the slider code for the six RGB channels is repeated in the
# MOUSEMOTION and MOUSEBUTTONDOWN handlers; a shared helper looks possible.
def play(self): while True: for event in pygame.event.get(): # quit event if event.type == gloc.QUIT: pygame.quit() sys.exit() # keyboard event elif event.type == gloc.KEYDOWN: # space / up key if event.key == gloc.K_SPACE or event.key == gloc.K_UP: # game screen, bird alive, not paused # ----> start game / flap wings if (not self.start and not self.ranking and not self.setting and not self.paused and self.bird.alive): self.pressed = True # limit bird height if self.bird.rect.top > -2 * self.bird.rect.height: self.bird.fly() self.sound['wing_sound'].play() # P / Esc key elif event.key == gloc.K_p or event.key == gloc.K_ESCAPE: # game screen, bird alive, first flap already made # ----> pause / resume game if (not self.start and not self.ranking and not self.setting and self.pressed and self.bird.alive): self.paused = not self.paused # G key elif event.key == gloc.K_g: if self.start and not hasattr(self, "ai_model"): self.init_vars(ai=True) self.ai_model = DoubleDQN() # mouse motion event elif event.type == gloc.MOUSEMOTION: # settings screen if self.setting and self.mouse_down: pos = pygame.mouse.get_pos() # RGB settings # body if pygame.Rect(64, 195, 40, 11).collidepoint(pos): self.body_rgb[0] = (pos[0] - 64) * 255 / 40 self.R1_set.set_point(self.body_rgb[0] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) elif pygame.Rect(125, 195, 40, 11).collidepoint(pos): self.body_rgb[1] = (pos[0] - 125) * 255 / 40 self.G1_set.set_point(self.body_rgb[1] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) elif pygame.Rect(189, 195, 40, 11).collidepoint(pos): self.body_rgb[2] = (pos[0] - 189) * 255 / 40 self.B1_set.set_point(self.body_rgb[2] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) # mouth elif pygame.Rect(64, 245, 40, 11).collidepoint(pos): self.mouth_rgb[0] = (pos[0] - 64) * 255 / 40 self.R2_set.set_point(self.mouth_rgb[0] / 255) self.customize_bird.seperate( 
self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) elif pygame.Rect(125, 245, 40, 11).collidepoint(pos): self.mouth_rgb[1] = (pos[0] - 125) * 255 / 40 self.G2_set.set_point(self.mouth_rgb[1] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) elif pygame.Rect(189, 245, 40, 11).collidepoint(pos): self.mouth_rgb[2] = (pos[0] - 189) * 255 / 40 self.B2_set.set_point(self.mouth_rgb[2] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) # volume settings elif pygame.Rect(105, 352, 110, 15).collidepoint(pos): self.volume = (pos[0] - 105) * 100 / 110 self.volume_set.set_point(self.volume / 100) pygame.mixer.music.set_volume(self.volume * 0.4 / 100) elif pygame.Rect(105, 402, 110, 15).collidepoint(pos): self.sound_volume = (pos[0] - 105) * 100 / 110 self.sound_set.set_point(self.sound_volume / 100) for i in self.sound.keys(): self.sound[i].set_volume( self.sound_volume * self.sound_default[i] / 100) # leaving the slider area ends the adjustment else: self.mouse_down = False # mouse button released elif event.type == gloc.MOUSEBUTTONUP: # settings screen if self.setting and self.mouse_down: self.mouse_down = False # mouse click event elif event.type == gloc.MOUSEBUTTONDOWN: pos = event.pos # left mouse button if event.button == 1: # start screen if self.start: # enter the game screen if self.start_image_rect.collidepoint(pos): self.start = False # enter the ranking screen elif self.score_image_rect.collidepoint(pos): self.start = False self.ranking = True # enter the settings screen elif self.setting_image_rect.collidepoint(pos): self.start = False self.setting = True # ranking screen elif self.ranking: # back to the start screen if self.back_rect.collidepoint(pos): self.ranking = False self.start = True # settings screen elif self.setting: # back to the start screen if self.setting_image_rect.collidepoint(pos): self.start = True self.setting = False setting.write_json(self.bird_color, self.background_index, self.volume, self.sound_volume) # bird settings elif 
pygame.Rect(52, 105, 30, 30)\ .collidepoint(pos): self.bird_color = (self.bird_color - 1) % 5 self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) elif pygame.Rect(202, 105, 30, 30)\ .collidepoint(pos): self.bird_color = (self.bird_color + 1) % 5 self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) # RGB settings # body elif pygame.Rect(64, 195, 40, 11)\ .collidepoint(pos): self.mouse_down = True self.body_rgb[0] = (pos[0] - 64) * 255 / 40 self.R1_set.set_point(self.body_rgb[0] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) elif pygame.Rect(125, 195, 40, 11)\ .collidepoint(pos): self.mouse_down = True self.body_rgb[1] = (pos[0] - 125) * 255 / 40 self.G1_set.set_point(self.body_rgb[1] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) elif pygame.Rect(189, 195, 40, 11)\ .collidepoint(pos): self.mouse_down = True self.body_rgb[2] = (pos[0] - 189) * 255 / 40 self.B1_set.set_point(self.body_rgb[2] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) # mouth elif pygame.Rect(64, 245, 40, 11)\ .collidepoint(pos): self.mouse_down = True self.mouth_rgb[0] = (pos[0] - 64) * 255 / 40 self.R2_set.set_point(self.mouth_rgb[0] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) elif pygame.Rect(125, 245, 40, 11)\ .collidepoint(pos): self.mouse_down = True self.mouth_rgb[1] = (pos[0] - 125) * 255 / 40 self.G2_set.set_point(self.mouth_rgb[1] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) elif pygame.Rect(189, 245, 40, 11)\ .collidepoint(pos): self.mouse_down = True self.mouth_rgb[2] = (pos[0] - 189) * 255 / 
40 self.B2_set.set_point(self.mouth_rgb[2] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) # background settings elif pygame.Rect(100, 292, 30, 30)\ .collidepoint(pos): self.background_index = ( self.background_index - 1) % 3 if self.background_index != 2: self.background = self.background_list[ self.background_index] elif pygame.Rect(200, 292, 30, 30)\ .collidepoint(pos): self.background_index = ( self.background_index + 1) % 3 if self.background_index != 2: self.background = self.background_list[ self.background_index] # volume settings elif pygame.Rect(105, 352, 110, 15)\ .collidepoint(pos): self.mouse_down = True self.volume = (pos[0] - 105) * 100 / 110 self.volume_set.set_point(self.volume / 100) pygame.mixer.music.set_volume(self.volume * 0.4 / 100) elif pygame.Rect(105, 402, 110, 15)\ .collidepoint(pos): self.mouse_down = True self.sound_volume = (pos[0] - 105) * 100 / 110 self.sound_set.set_point(self.sound_volume / 100) for i in self.sound.keys(): self.sound[i].set_volume( self.sound_volume * self.sound_default[i] / 100) # share screen elif self.share: if self.copy_rect.collidepoint(pos): try: share.copy(self.image_data) except AttributeError: pass elif self.save_rect.collidepoint(pos): share.save(self.image_data) elif self.email_rect.collidepoint(pos): share.send_email(self.image_data, self.score) elif self.back_rect.collidepoint(pos): self.share = False # game screen, bird alive elif (self.pressed and self.bird.alive and self.pause_image_rect.collidepoint(pos)): self.paused = not self.paused # ----> start game / flap wings elif not self.paused and self.bird.alive: self.pressed = True # limit bird height if self.bird.rect.top > -2 * self.bird.rect.height: self.bird.fly() self.sound['wing_sound'].play() # game-over screen elif not self.bird.alive: pos = pygame.mouse.get_pos() if self.retry_rect.collidepoint(pos): self.init_vars() self.start = False elif self.share_rect.collidepoint(pos): self.image_data = pygame.surfarray.array3d( 
pygame.display.get_surface()) self.share = True elif self.menu_rect.collidepoint(pos): self.init_vars() # base game frame self.screen.blit(self.background, (0, 0)) # draw the ground self.screen.blit(self.land.image, self.land.rect) if self.bird.alive and not self.paused: self.land.move() # start screen if self.start: # draw the title self.screen.blit(self.title, self.title_rect) # draw the start button self.screen.blit(self.start_image, self.start_image_rect) # draw the ranking button self.screen.blit(self.score_image, self.score_image_rect) # draw the settings button self.screen.blit(self.setting_image, self.setting_image_rect) # settings elif self.setting: self.screen.blit(self.board_image, self.board_rect) self.screen.blit(self.setting_image, self.setting_image_rect) # draw the bird settings self.screen.blit(self.array_left, (52, 105)) self.screen.blit(self.array_right, (202, 105)) if self.bird_color in [0, 1, 2]: self.screen.blit( self.bird.images[self.bird.image_index(self.delay)], (120, 100)) elif self.bird_color == 3: self.screen.blit( self.random_bird[self.bird.image_index(self.delay)], (120, 100)) self.screen.blit( self.random_text, ((self.width - self.random_text.get_width()) // 2, 150)) elif self.bird_color == 4: self.screen.blit( self.bird.images[self.bird.image_index(self.delay)], (120, 100)) self.screen.blit( self.custom_text, ((self.width - self.custom_text.get_width()) // 2, 150)) self.screen.blit( self.body_text, ((self.width - self.body_text.get_width()) // 2, 170)) self.body_rgb = list(self.bird.images[0].get_at((23, 24))) self.screen.blit(self.R_text, (50, 190)) self.R1_set.set_point(self.body_rgb[0] / 255) self.R1_set.display() self.screen.blit(self.G_text, (113, 190)) self.G1_set.set_point(self.body_rgb[1] / 255) self.G1_set.display() self.screen.blit(self.B_text, (175, 190)) self.B1_set.set_point(self.body_rgb[2] / 255) self.B1_set.display() self.screen.blit( self.mouth_text, ((self.width - self.mouth_text.get_width()) // 2, 220)) self.mouth_rgb = list(self.bird.images[0].get_at((30, 27))) self.screen.blit(self.R_text, (50, 240)) 
self.R2_set.set_point(self.mouth_rgb[0] / 255) self.R2_set.display() self.screen.blit(self.G_text, (113, 240)) self.G2_set.set_point(self.mouth_rgb[1] / 255) self.G2_set.display() self.screen.blit(self.B_text, (175, 240)) self.B2_set.set_point(self.mouth_rgb[2] / 255) self.B2_set.display() # draw the background settings self.screen.blit(self.bg_text, (50, 300)) self.screen.blit(self.array_left, (100, 292)) self.screen.blit(self.array_right, (200, 292)) self.screen.blit(self.bg_text_list[self.background_index], (150, 300)) # draw the music volume setting self.screen.blit(self.volume_text, (50, 350)) self.volume_set.display() # draw the sound-effect volume setting self.screen.blit(self.sound_text, (50, 400)) self.sound_set.display() # ranking screen elif self.ranking: self.screen.blit(self.board_image, self.board_rect) if self.value is None: self.value = score.Sql.get_score() for i in range(len(self.value)): self.screen.blit(self.cups[i], self.cup_rects[i]) time_tran = time.strftime("%Y/%m/%d %H:%M:%S", time.localtime( self.value[i][0])).split() score_text = self.rank_font.render(str(self.value[i][1]), True, (0, 0, 0)) time_text1 = self.setting_font.render( time_tran[0], True, (0, 0, 0)) time_text2 = self.setting_font.render( time_tran[1], True, (0, 0, 0)) self.screen.blit( score_text, (self.cup_rects[i][0] + 50, self.cup_rects[i][1] + 10)) self.screen.blit( time_text1, (self.cup_rects[i][0] + 95, self.cup_rects[i][1] + 5)) self.screen.blit(time_text2, (self.cup_rects[i][0] + 105, self.cup_rects[i][1] + 23)) self.screen.blit(self.back_image, self.back_rect) # share screen elif self.share: self.screen.blit(self.board_image, self.board_rect) self.screen.blit(self.copy_image, self.copy_rect) self.screen.blit(self.save_image, self.save_rect) self.screen.blit(self.email_image, self.email_rect) self.screen.blit(self.back_image, self.back_rect) # game screen else: # ready screen if not self.pressed: # draw the bird self.screen.blit( self.bird.images[self.bird.image_index(self.delay)], self.bird.rect) # draw "ready" self.screen.blit(self.ready_image, self.ready_rect) # draw the "press to start" hint 
self.screen.blit(self.press_start_image, self.press_start_rect) else: # move the bird if not self.paused: self.bird.move(self.delay) if self.ai and not self.paused and self.bird.alive: self.screen.blit(self.bg_black, (0, 0)) # draw the pipes for upipe, dpipe in zip(self.upperpipes, self.lowerpipes): self.screen.blit(upipe.image, upipe.rect) self.screen.blit(dpipe.image, dpipe.rect) # draw the bird self.screen.blit(self.bird.image, self.bird.rect) # draw the ground self.screen.blit(self.land.image, self.land.rect) if self.ai and not self.paused and self.bird.alive: img = pygame.surfarray.array3d( pygame.display.get_surface()) if not hasattr(self.ai_model, "currentState"): self.ai_model.currentState = \ self.ai_model.set_initial_state(img) self.ai_model.currentState = \ self.ai_model.update_state(img) _, action_index = self.ai_model.getAction() if action_index == 1: if self.bird.rect.top > -2 * self.bird.rect.height: self.bird.fly() if self.bird.alive: # draw the score score.display(self.screen, self.bg_size, self.score) if not self.paused: # draw the pause button self.screen.blit(self.pause_image, self.pause_image_rect) # move the pipes for upipe, dpipe in zip(self.upperpipes, self.lowerpipes): upipe.move() dpipe.move() else: # draw the resume button self.screen.blit(self.resume_image, self.resume_image_rect) # spawn and remove pipes if 0 < self.upperpipes[0].rect.left < 5: new_upipe, new_dpipe = pipe.get_pipe( self.bg_size, self.land.rect.top) self.upperpipes.append(new_upipe) self.lowerpipes.append(new_dpipe) self.pipe_group.add(new_upipe, new_dpipe) if self.upperpipes[0].rect.right < 0: self.pipe_group.remove(self.upperpipes[0], self.lowerpipes[0]) self.upperpipes.pop(0) self.lowerpipes.pop(0) # scoring if self.bird.alive: for upipe in self.upperpipes: if (upipe.rect.centerx <= self.bird.rect.centerx < upipe.rect.centerx + 4): self.score += 1 self.sound['point_sound'].play() # collision check if self.bird.alive and self.checkCrash(): self.bird.alive = False self.sound['hit_sound'].play() self.sound['die_sound'].play() # game-over screen if not self.bird.alive: # draw the "game over" text 
self.screen.blit(self.gameover_image, self.gameover_image_rect) # draw the score panel self.screen.blit(self.score_panel, self.score_panel_rect) score.show_score(self.screen, self.bg_size, self.score) if not self.recorded and self.value is None: self.value = score.Sql.get_score() if self.value: best_score = self.value[0][1] score.show_best(self.screen, self.bg_size, best_score) # draw the medal if self.score >= 100: self.screen.blit(self.white_medal, self.medal_rect) elif self.score >= 60: self.screen.blit(self.gold_medal, self.medal_rect) elif self.score >= 30: self.screen.blit(self.silver_medal, self.medal_rect) elif self.score >= 10: self.screen.blit(self.brooze_medal, self.medal_rect) # draw the retry button self.screen.blit(self.retry_image, self.retry_rect) self.screen.blit(self.share_image, self.share_rect) self.screen.blit(self.menu_image, self.menu_rect) # save the score if not self.recorded: new_record = score.Sql.set_score(self.score) self.value = score.Sql.get_score() self.recorded = True if new_record: self.screen.blit(self.new_image, self.new_rect) # refresh the display self.delay = (self.delay + 1) % 30 pygame.display.update() self.clock.tick(30)
class Game(): def __init__(self): pygame.init() self.bg_size = self.width, self.height = 288, 512 self.screen = pygame.display.set_mode(self.bg_size) pygame.display.set_caption("Flappy Bird") icon = pygame.image.load("assets/images/flappy.ico") pygame.display.set_icon(icon) pygame.mixer.init() pygame.mixer.set_num_channels(4) self.clock = pygame.time.Clock() self.init_sound() self.init_pics() self.init_vars() # 加载声音 def init_sound(self): self.sound = {} self.sound_default = {} # 背景音乐 self.bgm = pygame.mixer.music.load("assets/sound/bgm.ogg") pygame.mixer.music.set_volume(0.4) pygame.mixer.music.play(-1) # 死亡声音 self.sound['die_sound'] = pygame.mixer.Sound("assets/sound/die.ogg") self.sound_default['die_sound'] = 0.4 # 撞击声音 self.sound['hit_sound'] = pygame.mixer.Sound("assets/sound/hit.ogg") self.sound_default['hit_sound'] = 0.4 # 得分声音 self.sound['point_sound'] = pygame.mixer.Sound( "assets/sound/point.ogg") self.sound_default['point_sound'] = 0.4 # 拍翅膀声音 self.sound['wing_sound'] = pygame.mixer.Sound("assets/sound/wing.ogg") self.sound_default['wing_sound'] = 0.8 # 加载图片 def init_pics(self): # 加载背景与地面 self.bg_black = pygame.image.load("assets/images/bg_black.png")\ .convert_alpha() self.background_list = [ pygame.image.load("assets/images/bg_day.png").convert(), pygame.image.load("assets/images/bg_night.png").convert() ] self.land = land.Land(self.bg_size) # 游戏开始画面 # 游戏标题 self.title = pygame.image.load("assets/images/start/title.png")\ .convert_alpha() self.title_rect = self.title.get_rect() self.title_rect.left = (self.width - self.title_rect.width) // 2 self.title_rect.top = 80 # 开始按钮 self.start_image = pygame.image.load("assets/images/start/start.png")\ .convert_alpha() self.start_image_rect = self.start_image.get_rect() self.start_image_rect.left = (self.width - self.start_image_rect.width) // 2 self.start_image_rect.top = 240 # 排行榜按钮 self.score_image = pygame.image.load("assets/images/start/score.png")\ .convert_alpha() self.score_image_rect = 
self.score_image.get_rect() self.score_image_rect.left = (self.width - self.score_image_rect.width) // 2 self.score_image_rect.top = 310 # 设置按钮 self.setting_image = pygame.image.load("assets/images/start/setting.png")\ .convert_alpha() self.setting_image_rect = self.setting_image.get_rect() self.setting_image_rect.left = (self.width - self.setting_image_rect.width - 10) self.setting_image_rect.top = 10 # 排行画面 # 奖杯 self.cups = [ pygame.image.load( "assets/images/rank/gold_cup.png").convert_alpha(), pygame.image.load( "assets/images/rank/silver_cup.png").convert_alpha(), pygame.image.load( "assets/images/rank/brooze_cup.png").convert_alpha() ] self.cup_rects = [(50, 120), (50, 200), (50, 280)] # 字体 self.rank_font = pygame.font.Font("assets/font/hanyihaiyun.ttf", 24) # 设置画面 # 设置面板 self.board_image = pygame.image.load("assets/images/board.png")\ .convert_alpha() self.board_rect = self.board_image.get_rect() self.board_rect.top = 20 self.board_rect.left = (self.width - self.board_rect.width) // 2 # 左右箭头 self.array_right = pygame.image.load("assets/images/start/array.png")\ .convert_alpha() self.array_left = pygame.transform.rotate(self.array_right, 180) # 设置字体 self.setting_font = pygame.font.Font("assets/font/hanyihaiyun.ttf", 16) # 小鸟设置 self.random_text = self.setting_font.render("随机", True, (0, 0, 0)) self.custom_text = self.setting_font.render("自定义", True, (0, 0, 0)) # 随机小鸟设置 self.random_bird = [ pygame.image.load( "assets/images/birds/random_0.png").convert_alpha(), pygame.image.load( "assets/images/birds/random_1.png").convert_alpha(), pygame.image.load( "assets/images/birds/random_2.png").convert_alpha() ] # 自定义小鸟设置 self.body_text = self.setting_font.render("身体", True, (0, 0, 0)) self.mouth_text = self.setting_font.render("嘴", True, (0, 0, 0)) self.R_text = self.setting_font.render("R", True, (0, 0, 0)) self.G_text = self.setting_font.render("G", True, (0, 0, 0)) self.B_text = self.setting_font.render("B", True, (0, 0, 0)) self.customize_bird = 
setting.Customize_bird() # 背景设置 self.bg_text = self.setting_font.render("背景:", True, (0, 0, 0)) self.bg_text_list = [ self.setting_font.render("白天", True, (0, 0, 0)), self.setting_font.render("夜晚", True, (0, 0, 0)), self.random_text ] # 音量设置 self.volume_text = self.setting_font.render("音量:", True, (0, 0, 0)) self.sound_text = self.setting_font.render("音效:", True, (0, 0, 0)) # 游戏画面 # 准备图片 self.ready_image = pygame.image.load("assets/images/game/ready.png")\ .convert_alpha() self.ready_rect = self.ready_image.get_rect() self.ready_rect.left = (self.width - self.ready_rect.width) // 2 self.ready_rect.top = self.height * 0.12 # 点击开始图片 self.press_start_image = pygame.image.load("assets/images/game/tutorial.png")\ .convert_alpha() self.press_start_rect = self.press_start_image.get_rect() self.press_start_rect.left = (self.width - self.press_start_rect.width) // 2 self.press_start_rect.top = self.height * 0.5 # 暂停按钮 self.pause_image = pygame.image.load("assets/images/game/pause.png")\ .convert_alpha() self.pause_image_rect = self.pause_image.get_rect() self.pause_image_rect.left = (self.width - self.pause_image_rect.width - 10) self.pause_image_rect.top = 10 # 继续按钮 self.resume_image = pygame.image.load("assets/images/game/resume.png")\ .convert_alpha() self.resume_image_rect = self.resume_image.get_rect() self.resume_image_rect.left = (self.width - self.resume_image_rect.width - 10) self.resume_image_rect.top = 10 # 分享画面 # 复制到剪贴板 self.copy_image = pygame.image.load("assets/images/share/copy.png")\ .convert_alpha() self.copy_rect = self.copy_image.get_rect() self.copy_rect.left = (self.width - self.copy_rect.width) // 2 self.copy_rect.top = 110 # 保存至本地 self.save_image = pygame.image.load("assets/images/share/save.png")\ .convert_alpha() self.save_rect = self.save_image.get_rect() self.save_rect.left = (self.width - self.save_rect.width) // 2 self.save_rect.top = 200 # 使用邮件分享 self.email_image = pygame.image.load("assets/images/share/email.png")\ .convert_alpha() 
self.email_rect = self.email_image.get_rect() self.email_rect.left = (self.width - self.email_rect.width) // 2 self.email_rect.top = 290 # 返回 self.back_image = pygame.image.load("assets/images/share/back.png")\ .convert_alpha() self.back_rect = self.back_image.get_rect() self.back_rect.left = (self.width - self.back_rect.width) // 2 self.back_rect.top = 380 # 游戏结束画面 # 游戏结束图片 self.gameover_image = pygame.image.load("assets/images/end/gameover.png")\ .convert_alpha() self.gameover_image_rect = self.gameover_image.get_rect() self.gameover_image_rect.left = (self.width - self.gameover_image_rect.width) // 2 self.gameover_image_rect.top = self.height * 0.12 # 得分面版 self.score_panel = pygame.image.load("assets/images/end/score_panel.png")\ .convert_alpha() self.score_panel_rect = self.score_panel.get_rect() self.score_panel_rect.left = (self.width - self.score_panel_rect.width) // 2 self.score_panel_rect.top = self.height * 0.24 # 奖牌图片 self.white_medal = pygame.image.load("assets/images/end/medal0.png")\ .convert_alpha() self.gold_medal = pygame.image.load("assets/images/end/medal1.png")\ .convert_alpha() self.silver_medal = pygame.image.load("assets/images/end/medal2.png")\ .convert_alpha() self.brooze_medal = pygame.image.load("assets/images/end/medal3.png")\ .convert_alpha() self.medal_rect = (57, 165) # 新纪录图片 self.new_image = pygame.image.load("assets/images/end/new.png")\ .convert_alpha() self.new_rect = self.new_image.get_rect() self.new_rect.left, self.new_rect.top = 150, 139 # 再来一次图片 self.retry_image = pygame.image.load("assets/images/end/retry.png")\ .convert_alpha() self.retry_rect = self.retry_image.get_rect() self.retry_rect.left = (self.width - self.retry_rect.width) // 2 self.retry_rect.top = self.height * 0.5 # 分享按钮 self.share_image = pygame.image.load("assets/images/end/share.png")\ .convert_alpha() self.share_rect = self.share_image.get_rect() self.share_rect.left = (self.width - self.share_rect.width) // 2 self.share_rect.top = self.retry_rect.top + 30 # 
主菜单按钮 self.menu_image = pygame.image.load("assets/images/end/menu.png")\ .convert_alpha() self.menu_rect = self.menu_image.get_rect() self.menu_rect.left = (self.width - self.menu_rect.width) // 2 self.menu_rect.top = self.retry_rect.top + 60 # 初始化游戏数据 def init_vars(self, ai: bool = False): # 读取设置 (self.bird_color, self.background_index, self.volume, self.sound_volume) = setting.read_config() # 设置音量 pygame.mixer.music.set_volume(self.volume * 0.4 / 100) for i in self.sound.keys(): self.sound[i].set_volume(self.sound_volume * self.sound_default[i] / 100) # 游戏分数 self.score = 0 # 背景 if self.background_index == 2: pipe.PIPE_INDEX = random.choice([0, 1]) elif self.background_index in [0, 1]: pipe.PIPE_INDEX = self.background_index self.background = self.background_list[pipe.PIPE_INDEX] # 是否开挂 self.ai = ai # 游戏开始画面 self.start = True # 排行榜画面 self.ranking = False self.value = None # 设置画面 self.setting = False self.mouse_down = False self.R1_set = setting.Setting_line(self.screen, rect=(64, 199), lenth=40, point=0.5, color=(255, 0, 0), height=3) self.G1_set = setting.Setting_line(self.screen, rect=(125, 199), lenth=40, point=0.5, color=(0, 255, 0), height=3) self.B1_set = setting.Setting_line(self.screen, rect=(189, 199), lenth=40, point=0.5, color=(0, 0, 255), height=3) self.R2_set = setting.Setting_line(self.screen, rect=(64, 249), lenth=40, point=0.5, color=(255, 0, 0), height=3) self.G2_set = setting.Setting_line(self.screen, rect=(125, 249), lenth=40, point=0.5, color=(0, 255, 0), height=3) self.B2_set = setting.Setting_line(self.screen, rect=(189, 249), lenth=40, point=0.5, color=(0, 0, 255), height=3) self.volume_set = setting.Setting_line(self.screen, rect=(105, 358), lenth=110, point=self.volume / 100, color=(230, 100, 0)) self.sound_set = setting.Setting_line(self.screen, rect=(105, 408), lenth=110, point=self.sound_volume / 100, color=(230, 100, 0)) # 游戏画面 self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color, ai=ai) self.delay = 0 self.paused = 
False self.pressed = False self.upperpipes = [] self.lowerpipes = [] self.pipe_group = pygame.sprite.Group() if not ai: upipe, dpipe = pipe.get_pipe(self.bg_size, self.land.rect.top, self.width + 200) else: upipe, dpipe = pipe.get_pipe(self.bg_size, self.land.rect.top, self.width) self.upperpipes.append(upipe) self.lowerpipes.append(dpipe) self.pipe_group.add(upipe, dpipe) if not ai: upipe, dpipe = pipe.get_pipe(self.bg_size, self.land.rect.top, 1.5 * self.width + 200) else: upipe, dpipe = pipe.get_pipe(self.bg_size, self.land.rect.top, 1.5 * self.width) self.upperpipes.append(upipe) self.lowerpipes.append(dpipe) self.pipe_group.add(upipe, dpipe) # 游戏结束画面 self.recorded = False # 分享画面 self.share = False # 检测碰撞 def checkCrash(self): # if player crashes into ground if self.bird.rect.top + self.bird.rect.height\ >= self.land.rect.top + 1: return True playerRect = self.bird.rect for uPipe, lPipe in zip(self.upperpipes, self.lowerpipes): # upper and lower pipe rects uPipeRect = uPipe.rect lPipeRect = lPipe.rect # player and upper/lower pipe hitmasks pHitMask = self.bird.mask uHitmask = uPipe.mask lHitmask = lPipe.mask # if bird collided with upipe or lpipe uCollide = pixelCollision(playerRect, uPipeRect, pHitMask, uHitmask) lCollide = pixelCollision(playerRect, lPipeRect, pHitMask, lHitmask) if uCollide or lCollide: return True return False # 开始游戏 def play(self): while True: for event in pygame.event.get(): # 退出事件 if event.type == gloc.QUIT: pygame.quit() sys.exit() # 键盘事件 elif event.type == gloc.KEYDOWN: # 空格/上键 if event.key == gloc.K_SPACE or event.key == gloc.K_UP: # 游戏界面,小鸟存活,未暂停 # ----> 游戏开始/小鸟拍翅膀 if (not self.start and not self.ranking and not self.setting and not self.paused and self.bird.alive): self.pressed = True # 限制小鸟高度 if self.bird.rect.top > -2 * self.bird.rect.height: self.bird.fly() self.sound['wing_sound'].play() # P键/Esc键 elif event.key == gloc.K_p or event.key == gloc.K_ESCAPE: # 游戏界面,小鸟存活,未暂停 # ----> 游戏暂停/开始 if (not self.start and not self.ranking and 
not self.setting and self.pressed and self.bird.alive): self.paused = not self.paused # G键 elif event.key == gloc.K_g: if self.start and not hasattr(self, "ai_model"): self.init_vars(ai=True) self.ai_model = DoubleDQN() # 鼠标移动事件 elif event.type == gloc.MOUSEMOTION: # 设置界面 if self.setting and self.mouse_down: pos = pygame.mouse.get_pos() # RGB设置 # 身体 if pygame.Rect(64, 195, 40, 11).collidepoint(pos): self.body_rgb[0] = (pos[0] - 64) * 255 / 40 self.R1_set.set_point(self.body_rgb[0] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) elif pygame.Rect(125, 195, 40, 11).collidepoint(pos): self.body_rgb[1] = (pos[0] - 125) * 255 / 40 self.G1_set.set_point(self.body_rgb[1] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) elif pygame.Rect(189, 195, 40, 11).collidepoint(pos): self.body_rgb[2] = (pos[0] - 189) * 255 / 40 self.B1_set.set_point(self.body_rgb[2] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) # 嘴 elif pygame.Rect(64, 245, 40, 11).collidepoint(pos): self.mouth_rgb[0] = (pos[0] - 64) * 255 / 40 self.R2_set.set_point(self.mouth_rgb[0] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) elif pygame.Rect(125, 245, 40, 11).collidepoint(pos): self.mouth_rgb[1] = (pos[0] - 125) * 255 / 40 self.G2_set.set_point(self.mouth_rgb[1] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) elif pygame.Rect(189, 245, 40, 11).collidepoint(pos): self.mouth_rgb[2] = (pos[0] - 189) * 255 / 40 self.B2_set.set_point(self.mouth_rgb[2] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, 
self.land.rect.top, self.bird_color) # 音量设置 elif pygame.Rect(105, 352, 110, 15).collidepoint(pos): self.volume = (pos[0] - 105) * 100 / 110 self.volume_set.set_point(self.volume / 100) pygame.mixer.music.set_volume(self.volume * 0.4 / 100) elif pygame.Rect(105, 402, 110, 15).collidepoint(pos): self.sound_volume = (pos[0] - 105) * 100 / 110 self.sound_set.set_point(self.sound_volume / 100) for i in self.sound.keys(): self.sound[i].set_volume( self.sound_volume * self.sound_default[i] / 100) # 移出区域视为设置结束 else: self.mouse_down = False # 鼠标点击释放 elif event.type == gloc.MOUSEBUTTONUP: # 设置界面 if self.setting and self.mouse_down: self.mouse_down = False # 鼠标点击事件 elif event.type == gloc.MOUSEBUTTONDOWN: pos = event.pos # 鼠标左键 if event.button == 1: # 开始界面 if self.start: # 进入游戏界面 if self.start_image_rect.collidepoint(pos): self.start = False # 进入排行界面 elif self.score_image_rect.collidepoint(pos): self.start = False self.ranking = True # 进入设置界面 elif self.setting_image_rect.collidepoint(pos): self.start = False self.setting = True # 排行榜界面 elif self.ranking: # 回到开始界面 if self.back_rect.collidepoint(pos): self.ranking = False self.start = True # 设置界面 elif self.setting: # 回到开始界面 if self.setting_image_rect.collidepoint(pos): self.start = True self.setting = False setting.write_json(self.bird_color, self.background_index, self.volume, self.sound_volume) # 小鸟设置 elif pygame.Rect(52, 105, 30, 30)\ .collidepoint(pos): self.bird_color = (self.bird_color - 1) % 5 self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) elif pygame.Rect(202, 105, 30, 30)\ .collidepoint(pos): self.bird_color = (self.bird_color + 1) % 5 self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) # RGB设置 # 身体 elif pygame.Rect(64, 195, 40, 11)\ .collidepoint(pos): self.mouse_down = True self.body_rgb[0] = (pos[0] - 64) * 255 / 40 self.R1_set.set_point(self.body_rgb[0] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, 
self.land.rect.top, self.bird_color) elif pygame.Rect(125, 195, 40, 11)\ .collidepoint(pos): self.mouse_down = True self.body_rgb[1] = (pos[0] - 125) * 255 / 40 self.G1_set.set_point(self.body_rgb[1] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) elif pygame.Rect(189, 195, 40, 11)\ .collidepoint(pos): self.mouse_down = True self.body_rgb[2] = (pos[0] - 189) * 255 / 40 self.B1_set.set_point(self.body_rgb[2] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) # 嘴 elif pygame.Rect(64, 245, 40, 11)\ .collidepoint(pos): self.mouse_down = True self.mouth_rgb[0] = (pos[0] - 64) * 255 / 40 self.R2_set.set_point(self.mouth_rgb[0] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) elif pygame.Rect(125, 245, 40, 11)\ .collidepoint(pos): self.mouse_down = True self.mouth_rgb[1] = (pos[0] - 125) * 255 / 40 self.G2_set.set_point(self.mouth_rgb[1] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) elif pygame.Rect(189, 245, 40, 11)\ .collidepoint(pos): self.mouse_down = True self.mouth_rgb[2] = (pos[0] - 189) * 255 / 40 self.B2_set.set_point(self.mouth_rgb[2] / 255) self.customize_bird.seperate( self.body_rgb, self.mouth_rgb) self.bird = bird.Bird(self.bg_size, self.land.rect.top, self.bird_color) # 背景设置 elif pygame.Rect(100, 292, 30, 30)\ .collidepoint(pos): self.background_index = ( self.background_index - 1) % 3 if self.background_index != 2: self.background = self.background_list[ self.background_index] elif pygame.Rect(200, 292, 30, 30)\ .collidepoint(pos): self.background_index = ( self.background_index + 1) % 3 if self.background_index != 2: self.background = self.background_list[ self.background_index] # 音量设置 elif 
pygame.Rect(105, 352, 110, 15)\ .collidepoint(pos): self.mouse_down = True self.volume = (pos[0] - 105) * 100 / 110 self.volume_set.set_point(self.volume / 100) pygame.mixer.music.set_volume(self.volume * 0.4 / 100) elif pygame.Rect(105, 402, 110, 15)\ .collidepoint(pos): self.mouse_down = True self.sound_volume = (pos[0] - 105) * 100 / 110 self.sound_set.set_point(self.sound_volume / 100) for i in self.sound.keys(): self.sound[i].set_volume( self.sound_volume * self.sound_default[i] / 100) # 分享画面 elif self.share: if self.copy_rect.collidepoint(pos): try: share.copy(self.image_data) except AttributeError: pass elif self.save_rect.collidepoint(pos): share.save(self.image_data) elif self.email_rect.collidepoint(pos): share.send_email(self.image_data, self.score) elif self.back_rect.collidepoint(pos): self.share = False # 游戏界面,小鸟存活 elif (self.pressed and self.bird.alive and self.pause_image_rect.collidepoint(pos)): self.paused = not self.paused # ----> 游戏开始/小鸟拍翅膀 elif not self.paused and self.bird.alive: self.pressed = True # 限制小鸟高度 if self.bird.rect.top > -2 * self.bird.rect.height: self.bird.fly() self.sound['wing_sound'].play() # 游戏结束界面 elif not self.bird.alive: pos = pygame.mouse.get_pos() if self.retry_rect.collidepoint(pos): self.init_vars() self.start = False elif self.share_rect.collidepoint(pos): self.image_data = pygame.surfarray.array3d( pygame.display.get_surface()) self.share = True elif self.menu_rect.collidepoint(pos): self.init_vars() # 游戏基础画面 self.screen.blit(self.background, (0, 0)) # 绘制地面 self.screen.blit(self.land.image, self.land.rect) if self.bird.alive and not self.paused: self.land.move() # 游戏开始画面 if self.start: # 绘制游戏名 self.screen.blit(self.title, self.title_rect) # 绘制开始按钮 self.screen.blit(self.start_image, self.start_image_rect) # 绘制排行按钮 self.screen.blit(self.score_image, self.score_image_rect) # 绘制设置按钮 self.screen.blit(self.setting_image, self.setting_image_rect) # 设置 elif self.setting: self.screen.blit(self.board_image, self.board_rect) 
self.screen.blit(self.setting_image, self.setting_image_rect) # 绘制小鸟设置 self.screen.blit(self.array_left, (52, 105)) self.screen.blit(self.array_right, (202, 105)) if self.bird_color in [0, 1, 2]: self.screen.blit( self.bird.images[self.bird.image_index(self.delay)], (120, 100)) elif self.bird_color == 3: self.screen.blit( self.random_bird[self.bird.image_index(self.delay)], (120, 100)) self.screen.blit( self.random_text, ((self.width - self.random_text.get_width()) // 2, 150)) elif self.bird_color == 4: self.screen.blit( self.bird.images[self.bird.image_index(self.delay)], (120, 100)) self.screen.blit( self.custom_text, ((self.width - self.custom_text.get_width()) // 2, 150)) self.screen.blit( self.body_text, ((self.width - self.body_text.get_width()) // 2, 170)) self.body_rgb = list(self.bird.images[0].get_at((23, 24))) self.screen.blit(self.R_text, (50, 190)) self.R1_set.set_point(self.body_rgb[0] / 255) self.R1_set.display() self.screen.blit(self.G_text, (113, 190)) self.G1_set.set_point(self.body_rgb[1] / 255) self.G1_set.display() self.screen.blit(self.B_text, (175, 190)) self.B1_set.set_point(self.body_rgb[2] / 255) self.B1_set.display() self.screen.blit( self.mouth_text, ((self.width - self.mouth_text.get_width()) // 2, 220)) self.mouth_rgb = list(self.bird.images[0].get_at((30, 27))) self.screen.blit(self.R_text, (50, 240)) self.R2_set.set_point(self.mouth_rgb[0] / 255) self.R2_set.display() self.screen.blit(self.G_text, (113, 240)) self.G2_set.set_point(self.mouth_rgb[1] / 255) self.G2_set.display() self.screen.blit(self.B_text, (175, 240)) self.B2_set.set_point(self.mouth_rgb[2] / 255) self.B2_set.display() # 绘制背景设置 self.screen.blit(self.bg_text, (50, 300)) self.screen.blit(self.array_left, (100, 292)) self.screen.blit(self.array_right, (200, 292)) self.screen.blit(self.bg_text_list[self.background_index], (150, 300)) # 绘制音量设置 self.screen.blit(self.volume_text, (50, 350)) self.volume_set.display() # 绘制音效设置 self.screen.blit(self.sound_text, (50, 400)) 
self.sound_set.display() # 排行界面 elif self.ranking: self.screen.blit(self.board_image, self.board_rect) if self.value is None: self.value = score.Sql.get_score() for i in range(len(self.value)): self.screen.blit(self.cups[i], self.cup_rects[i]) time_tran = time.strftime("%Y/%m/%d %H:%M:%S", time.localtime( self.value[i][0])).split() score_text = self.rank_font.render(str(self.value[i][1]), True, (0, 0, 0)) time_text1 = self.setting_font.render( time_tran[0], True, (0, 0, 0)) time_text2 = self.setting_font.render( time_tran[1], True, (0, 0, 0)) self.screen.blit( score_text, (self.cup_rects[i][0] + 50, self.cup_rects[i][1] + 10)) self.screen.blit( time_text1, (self.cup_rects[i][0] + 95, self.cup_rects[i][1] + 5)) self.screen.blit(time_text2, (self.cup_rects[i][0] + 105, self.cup_rects[i][1] + 23)) self.screen.blit(self.back_image, self.back_rect) # 分享画面 elif self.share: self.screen.blit(self.board_image, self.board_rect) self.screen.blit(self.copy_image, self.copy_rect) self.screen.blit(self.save_image, self.save_rect) self.screen.blit(self.email_image, self.email_rect) self.screen.blit(self.back_image, self.back_rect) # 游戏画面 else: # 准备画面 if not self.pressed: # 绘制小鸟 self.screen.blit( self.bird.images[self.bird.image_index(self.delay)], self.bird.rect) # 绘制ready self.screen.blit(self.ready_image, self.ready_rect) # 绘制press开始 self.screen.blit(self.press_start_image, self.press_start_rect) else: # 移动小鸟 if not self.paused: self.bird.move(self.delay) if self.ai and not self.paused and self.bird.alive: self.screen.blit(self.bg_black, (0, 0)) # 绘制pipe for upipe, dpipe in zip(self.upperpipes, self.lowerpipes): self.screen.blit(upipe.image, upipe.rect) self.screen.blit(dpipe.image, dpipe.rect) # 绘制小鸟 self.screen.blit(self.bird.image, self.bird.rect) # 绘制地面 self.screen.blit(self.land.image, self.land.rect) if self.ai and not self.paused and self.bird.alive: img = pygame.surfarray.array3d( pygame.display.get_surface()) if not hasattr(self.ai_model, "currentState"): 
self.ai_model.currentState = \ self.ai_model.set_initial_state(img) self.ai_model.currentState = \ self.ai_model.update_state(img) _, action_index = self.ai_model.getAction() if action_index == 1: if self.bird.rect.top > -2 * self.bird.rect.height: self.bird.fly() if self.bird.alive: # 绘制分数 score.display(self.screen, self.bg_size, self.score) if not self.paused: # 绘制暂停按钮 self.screen.blit(self.pause_image, self.pause_image_rect) # 移动pipe for upipe, dpipe in zip(self.upperpipes, self.lowerpipes): upipe.move() dpipe.move() else: # 绘制继续按钮 self.screen.blit(self.resume_image, self.resume_image_rect) # 生成和删除pipe if 0 < self.upperpipes[0].rect.left < 5: new_upipe, new_dpipe = pipe.get_pipe( self.bg_size, self.land.rect.top) self.upperpipes.append(new_upipe) self.lowerpipes.append(new_dpipe) self.pipe_group.add(new_upipe, new_dpipe) if self.upperpipes[0].rect.right < 0: self.pipe_group.remove(self.upperpipes[0], self.lowerpipes[0]) self.upperpipes.pop(0) self.lowerpipes.pop(0) # 得分 if self.bird.alive: for upipe in self.upperpipes: if (upipe.rect.centerx <= self.bird.rect.centerx < upipe.rect.centerx + 4): self.score += 1 self.sound['point_sound'].play() # 检测碰撞 if self.bird.alive and self.checkCrash(): self.bird.alive = False self.sound['hit_sound'].play() self.sound['die_sound'].play() # 游戏结束画面 if not self.bird.alive: # 绘制gameover字样 self.screen.blit(self.gameover_image, self.gameover_image_rect) # 绘制成绩面板 self.screen.blit(self.score_panel, self.score_panel_rect) score.show_score(self.screen, self.bg_size, self.score) if not self.recorded and self.value is None: self.value = score.Sql.get_score() if self.value: best_score = self.value[0][1] score.show_best(self.screen, self.bg_size, best_score) # 绘制奖牌 if self.score >= 100: self.screen.blit(self.white_medal, self.medal_rect) elif self.score >= 60: self.screen.blit(self.gold_medal, self.medal_rect) elif self.score >= 30: self.screen.blit(self.silver_medal, self.medal_rect) elif self.score >= 10: 
self.screen.blit(self.brooze_medal, self.medal_rect) # 绘制重新开始 self.screen.blit(self.retry_image, self.retry_rect) self.screen.blit(self.share_image, self.share_rect) self.screen.blit(self.menu_image, self.menu_rect) # 保存分数 if not self.recorded: new_record = score.Sql.set_score(self.score) self.value = score.Sql.get_score() self.recorded = True if new_record: self.screen.blit(self.new_image, self.new_rect) # 画面刷新 self.delay = (self.delay + 1) % 30 pygame.display.update() self.clock.tick(30) # AI游戏 def intelligence(self, input_action): pygame.event.pump() reward = 0.1 if sum(input_action) != 1: raise ValueError("Action Error") if input_action[1] == 1: if self.bird.rect.top > -2 * self.bird.rect.height: self.bird.fly() self.bird.move(self.delay) if self.bird.rect.top < 0: self.bird.rect.top = 0 # 移动pipe for upipe, dpipe in zip(self.upperpipes, self.lowerpipes): upipe.move() dpipe.move() # 生成和删除pipe if 0 < self.upperpipes[0].rect.left < 5: new_upipe, new_dpipe = pipe.get_pipe(self.bg_size, self.land.rect.top) self.upperpipes.append(new_upipe) self.lowerpipes.append(new_dpipe) self.pipe_group.add(new_upipe, new_dpipe) if self.upperpipes[0].rect.right < 0: self.pipe_group.remove(self.upperpipes[0], self.lowerpipes[0]) self.upperpipes.pop(0) self.lowerpipes.pop(0) # 得分 if self.bird.alive: for upipe in self.upperpipes: if (upipe.rect.centerx <= self.bird.rect.centerx < upipe.rect.centerx + 4): self.score += 1 reward = 1 # 地面碰撞 if self.bird.alive and self.checkCrash(): self.bird.alive = False self.init_vars(ai=True) reward = -1 self.screen.blit(self.bg_black, (0, 0)) # 绘制pipe for upipe, dpipe in zip(self.upperpipes, self.lowerpipes): self.screen.blit(upipe.image, upipe.rect) self.screen.blit(dpipe.image, dpipe.rect) # 绘制小鸟 self.screen.blit(self.bird.image, self.bird.rect) # 绘制地面 self.screen.blit(self.land.image, self.land.rect) if self.bird.alive: self.land.move() self.delay = (self.delay + 1) % 30 image_data = pygame.surfarray.array3d(pygame.display.get_surface()) 
score.display(self.screen, self.bg_size, self.score) pygame.display.update() self.clock.tick(30) return image_data, reward, not self.bird.alive
return total_reward n_episode = 200 total_reward_episode = [0] * n_episode total_mod_reward_episode = [0] * n_episode n_state = env.observation_space.shape[0] n_action = env.action_space.n n_hidden = 50 lr = 0.001 memory = deque(maxlen=10000) total_memory = deque(maxlen=1000) dqn = DoubleDQN(n_state, n_action, n_hidden, lr) q_learning(env, dqn, n_episode, gamma=0.9, epsilon=0.3, replay_size=20) plt.plot(total_reward_episode) plt.xlabel('Episode') plt.ylabel('Reward') plt.show() plt.plot(total_mod_reward_episode) plt.xlabel('Episode') plt.ylabel('Reward') plt.show() dqn.save()
def __init__(self, game, agent_type, display, load_model, record, test):
    """Configure the ALE emulator, the replay memory and the Q-network.

    :param game: ROM name; ``./roms/<game>.bin`` is loaded into the emulator.
    :param agent_type: ``'dqn'`` selects ``DeepQN``, ``'double'`` selects
        ``DoubleDQN``; every other value falls back to ``DuelingDQN``.
    :param display: enables the ALE ``display_screen`` option (it is also
        enabled whenever ``record`` is truthy).
    :param load_model: forwarded to the model constructor. When truthy and
        not recording, ``self.load_replaymemory()`` is called instead of
        creating a fresh ``ReplayMemory``.
    :param record: when truthy, ALE is pointed at a per-game / per-agent
        recording directory.
    :param test: when truthy, ``'_test'`` is appended to ``self.name``
        (after the ROM has been loaded under the plain name).
    """
    self.name = game
    self.agent_type = agent_type

    # Emulator set-up; option names and values are handed to ALE as bytes.
    emulator = ALEInterface()
    self.ale = emulator
    emulator.setInt('random_seed'.encode(), np.random.randint(100))
    emulator.setBool('display_screen'.encode(), display or record)
    if record:
        recording_dir = './data/recordings/{}/{}/tmp/'.format(game, agent_type)
        emulator.setString('record_screen_dir'.encode(),
                           recording_dir.encode())
    rom_path = './roms/{}.bin'.format(self.name)
    emulator.loadROM(rom_path.encode())

    self.action_list = list(emulator.getMinimalActionSet())
    self.frame_shape = np.squeeze(emulator.getScreenGrayscale()).shape

    if test:
        self.name += '_test'

    # Space Invaders uses a shorter frame skip to account for blinking
    # bullets.
    self.frameskip = 2 if 'space_invaders' in self.name else 3

    self.frame_buffer = deque(maxlen=4)

    if not load_model or record:
        self.replay_memory = ReplayMemory(500000, 32)
    else:
        # presumably this populates self.replay_memory - verify against
        # load_replaymemory, which is not visible here.
        self.load_replaymemory()

    model_input_shape = self.frame_shape + (4,)
    model_output_shape = len(self.action_list)

    # Pick the network class first, then build it once with shared arguments.
    if agent_type == 'dqn':
        model_cls = DeepQN
    elif agent_type == 'double':
        model_cls = DoubleDQN
    else:
        model_cls = DuelingDQN
    self.model = model_cls(
        model_input_shape,
        model_output_shape,
        self.action_list,
        self.replay_memory,
        self.name,
        load_model,
    )

    print('{} Loaded!'.format(' '.join(self.name.split('_')).title()))
    print('Displaying: ', display)
    print('Frame Shape: ', self.frame_shape)
    print('Frame Skip: ', self.frameskip)
    print('Action Set: ', self.action_list)
    print('Model Input Shape: ', model_input_shape)
    print('Model Output Shape: ', model_output_shape)
    print('Agent: ', agent_type)
target_update_interval=params.model_copy, clip_delta=True, update_interval=params.update_interval, batch_accumulator="mean", phi=phi, max_grad_norm=params.grad_norm ) else: agent = DoubleDQN( q_func, optimizer, rbuf, gpu=0, gamma=params.discount, explorer=explorer, replay_start_size=params.warmup, target_update_interval=params.model_copy, minibatch_size=params.BS, clip_delta=True, update_interval=params.update_interval, batch_accumulator="mean", phi=phi, max_grad_norm=params.grad_norm ) if params.RND_reward: rnd_module = RND_module(RNDNet, torch.optim.Adam, params.lr, agent) agent.set_rnd_module(rnd_module) if params.NGU_reward: ngu_module = NGU_module(Embedding_fn, Embedding_full, torch.optim.Adam, params.lr, agent, env.action_space.n, params.max_frames, params.ngu_embed_size, params.ngu_k_neighbors)
def set_model(self, model):
    """Wrap ``model`` in a ``DoubleDQN`` and install it as ``self.net``.

    :param model: per the original note, a compiled tf Model whose
        input_shape equals env.observation_space and whose output_shape
        equals env.action_space.
    """
    wrapped_net = DoubleDQN(model)
    self.net = wrapped_net
class DQNAgent:
    """Epsilon-greedy agent that trains a ``DoubleDQN`` wrapper on ``env``.

    The environment object must provide ``get_state_shape``,
    ``get_action_shape``, ``get_legal_actions(state)``, ``reset()`` and
    ``step(action)``; those are the only calls made on it here.
    """

    # Shared tail of the "model not set" error raised by several methods.
    _MODEL_MISSING_SUFFIX = (
        ' before model was not initiated.\n Please set the agent\'s model'
        ' using the set_model method. You can access the state and action '
        'shapes using agent\'s methods \'get_state_shape\' and '
        '\'get_action_shape\''
    )

    def __init__(self, env, net_update_rate: int = 25,
                 exploration_rate: float = 1.0,
                 exploration_decay: float = 0.00005):
        """Store hyper-parameters and create an empty replay buffer.

        :param env: environment the agent acts in.
        :param net_update_rate: number of steps between target-network
            alignments inside an episode.
        :param exploration_rate: initial epsilon used by ``train``.
        :param exploration_decay: decay constant handed to ``train``'s
            ``exp_decay_func``.
        """
        # Hyper-parameters.
        self.exploration_rate = exploration_rate
        self.exploration_decay = exploration_decay
        self.net_updating_rate = net_update_rate

        # Environment and the shapes it reports.
        self.env = env
        self.state_shape = env.get_state_shape()
        self.action_shape = env.get_action_shape()

        # Experience replay buffer used for batch learning.
        self.exp_rep = ExperienceReplay()

        # Q-network wrapper; stays None until set_model() is called.
        self.net = None

    def _require_model(self, caller: str) -> None:
        """Raise ``NotImplementedError`` if no model has been set yet.

        :param caller: leading phrase of the error message, e.g.
            ``'agent.train called'``.
        """
        if self.net is None:
            raise NotImplementedError(caller + self._MODEL_MISSING_SUFFIX)

    def set_model(self, model):
        """Wrap ``model`` in a ``DoubleDQN`` and install it as ``self.net``.

        :param model: per the original note, a compiled tf Model whose
            input_shape equals env.observation_space and whose output_shape
            equals env.action_space.
        """
        self.net = DoubleDQN(model)

    def get_action(self, state: np.ndarray, eps=0) -> int:
        """Choose a legal action epsilon-greedily.

        With probability ``eps`` a uniformly random legal action is returned;
        otherwise one of the legal actions with the highest predicted
        Q-value (ties broken at random).

        :param state: current environment state.
        :param eps: exploration probability.
        :raises NotImplementedError: if ``set_model`` has not been called.
        :raises ValueError: from ``np.random.choice`` when the environment
            reports no legal action at all.
        """
        self._require_model('agent.get_action called')

        legal_actions = self.env.get_legal_actions(state)
        if np.random.random() < eps:
            # Exploration: any legal action.
            return np.random.choice(np.flatnonzero(legal_actions))

        # Exploitation: Q-value of each action for this single state.
        q_values = np.asarray(
            self.net.predict(state[np.newaxis, ...],
                             np.expand_dims(legal_actions, 0)))

        # Take the maximum over *legal* actions only. Comparing against the
        # global maximum leaves no candidate whenever an illegal action has
        # the largest value, which made np.random.choice fail on an empty
        # array.
        legal_mask = np.asarray(legal_actions).astype(bool)
        legal_q = np.where(legal_mask, q_values, -np.inf)
        best = np.logical_and(legal_mask, legal_q == np.max(legal_q))
        return np.random.choice(np.flatnonzero(best))

    def update_net(self, batch_size: int):
        """Fit the network on one replay batch, if enough experience exists.

        Returns without doing anything while the replay buffer holds fewer
        than ``batch_size`` experiences.
        """
        if self.exp_rep.get_num() < batch_size:
            return
        batch = self.exp_rep.get_batch(batch_size)
        self.net.fit(*batch)

    def train(self, episodes: int, path: str, checkpoint_rate=100,
              batch_size: int = 64,
              exp_decay_func=lambda exp_rate, exp_decay, i: (
                  0.01 + (exp_rate - 0.01)
                  * np.exp(-abs(exp_decay) * (i + 1))),
              show_progress=False):
        """Run a training session for the agent.

        :param episodes: number of episodes to train.
        :param path: directory where weights and ``rewards.csv`` are saved.
        :param checkpoint_rate: weights are saved every ``checkpoint_rate``
            episodes once the replay buffer holds more than ``batch_size``
            experiences.
        :param batch_size: number of experiences used in each net update.
        :param exp_decay_func: ``f(rate, decay, episode) -> new rate``. The
            default shrinks the rate towards 0.01; the sign of ``decay`` is
            ignored so a positive constant (the constructor default) decays
            instead of growing the rate.
        :param show_progress: plot a moving average of rewards at
            checkpoints.
        :raises NotImplementedError: if ``set_model`` has not been called.
        """
        self._require_model('agent.train called')

        exploration_rate = self.exploration_rate
        total_rewards = []

        for episode in tqdm(range(episodes)):
            state = self.env.reset()  # fresh episode
            step, episode_reward = 0, 0
            is_terminal = False

            # Run until the chosen action leads to a terminal state.
            while not is_terminal:
                step += 1

                # Epsilon-greedy action for the current state.
                action = self.get_action(state, exploration_rate)

                # Apply it and observe the outcome.
                next_state, reward, is_terminal = self.env.step(action)[:3]
                episode_reward += reward

                # NOTE(review): the legal-action mask stored here is the one
                # of `state`, not `next_state` - confirm this is what
                # ExperienceReplay / DoubleDQN.fit expect.
                self.exp_rep.add(state, action, reward, next_state,
                                 self.env.get_legal_actions(state),
                                 is_terminal)

                # Optimize the DoubleQ-net.
                self.update_net(batch_size)

                if step % self.net_updating_rate == 0:
                    # Periodic target-network update.
                    self.net.align_target_model()
                state = next_state

            # Keep track of progress.
            total_rewards.append(episode_reward)

            # Target network is also aligned at the end of every episode.
            self.net.align_target_model()

            exploration_rate = exp_decay_func(exploration_rate,
                                              self.exploration_decay,
                                              episode)

            if (episode % checkpoint_rate == 0
                    and self.exp_rep.get_num() > batch_size):
                self.save_weights(
                    os.path.join(path, f'episode_{episode}_weights'))
                # NOTE(review): nesting reconstructed from whitespace-mangled
                # source; plotting is assumed to happen at checkpoints only.
                if show_progress:
                    # Plot a moving average of last 10 episodes.
                    self.plot_progress(total_rewards)

        # Remember the rate in case more training is needed later.
        self.exploration_rate = exploration_rate

        # Save the per-episode rewards as a csv file under `path`.
        with open(os.path.join(path, 'rewards.csv'), 'w') as reward_file:
            rewards = pd.DataFrame(total_rewards)
            rewards.to_csv(reward_file)
        self.save_weights(os.path.join(path, 'final_weights'))

    def plot_progress(self, total_rewards):
        """Plot the 10-episode moving average of ``total_rewards``."""
        w = np.ones(10) / 10
        moving_average = np.convolve(total_rewards, w, mode='valid')
        plt.plot(np.arange(len(moving_average)), moving_average)
        plt.title('Moving average of rewards across episodes')
        plt.xlabel('episodes')
        plt.ylabel('average reward over last 10 episodes')
        plt.show()

    def get_state_shape(self):
        """Return the state shape reported by the environment."""
        return self.state_shape

    def get_action_shape(self):
        """Return the action shape reported by the environment."""
        return self.action_shape

    # Handles saving\loading the model as explained here:
    # https://www.tensorflow.org/guide/keras/save_and_serialize
    def load_weights(self, path):
        """Load network weights from ``path`` (delegates to ``self.net``)."""
        self.net.load_weights(path)

    def save_weights(self, path):
        """Save network weights to ``path`` (delegates to ``self.net``)."""
        self.net.save_weights(path)

    def save_model(self, path):
        """Save the whole model to ``path``.

        :raises NotImplementedError: if ``set_model`` has not been called.
        """
        self._require_model('agent.save_model was called')
        self.net.save_model(path)

    def load_model(self, path):
        """Load a model from ``path`` and install it via ``set_model``."""
        # Inside a method the bare name resolves to the module-level
        # `load_model`, not to this method.
        model = load_model(path)
        self.set_model(model)

    def to_json(self, **kwargs):
        """Return the network's JSON description.

        :raises NotImplementedError: if ``set_model`` has not been called.
        """
        self._require_model('agent.to_json was called')
        return self.net.to_json(**kwargs)

    def from_json(self, json_config):
        """Build a model from ``json_config`` and install it."""
        model = model_from_json(json_config)
        self.set_model(model)
def __init__(self, mode, load=True):
    """Initialise the ``DoubleDQN`` base with this agent's fixed settings.

    :param mode: forwarded unchanged as ``mode``.
    :param load: forwarded as ``load_previous_model``.
    """
    # 14 actions and the DQCNN weights file are fixed for this subclass.
    base_kwargs = dict(
        action_count=14,
        weights_file_path=DQCNN_WEIGHTS_FILE,
        mode=mode,
        load_previous_model=load,
    )
    DoubleDQN.__init__(self, **base_kwargs)