def initialize_vizdoom(config):
    """Build, configure and start a ``DoomGame`` instance.

    Args:
        config: Value handed straight to ``DoomGame.load_config``.

    Returns:
        The ``DoomGame`` object after ``init()`` has been called on it.
    """
    doom = DoomGame()
    doom.load_config(config)

    # The window is shown only when the parsed arguments request demo mode.
    doom.set_window_visible(bool(handleArguments().demo_mode))

    # Fixed runtime settings applied on top of the loaded config.
    doom.set_mode(Mode.PLAYER)
    doom.set_screen_format(ScreenFormat.GRAY8)
    doom.set_screen_resolution(ScreenResolution.RES_640X480)

    doom.init()
    return doom
def run(self):
    """Play episodes with the local network until the episode budget is spent.

    Each step picks an action with ``self.lnet``, reports on
    ``self.action_queue`` whether that action is in ``attack`` (1) or not
    (0), advances the game and buffers the transition.  ``optimize`` is
    called whenever the episode ends or ``total_step`` reaches a multiple of
    ``UPDATE_GLOBAL_ITER``.  When an episode ends, its result is passed to
    ``record`` and the local gradients are assigned onto the global
    network's parameters.

    After the loop, ``None`` is put on the time, result and action queues
    (presumably an end-of-stream marker for the consumers -- confirm against
    the reader side).
    """
    total_step = 1

    # Use the full MAX_EP budget when a saved model is loaded; otherwise cap
    # the run at three times the CPU count.
    if handleArguments().load_model:
        inner_max_ep = MAX_EP
    else:
        inner_max_ep = mp.cpu_count() * 3

    while self.inner_ep.value < inner_max_ep:
        self.game.new_episode()
        state = game_state(self.game)
        buffer_s, buffer_a, buffer_r = [], [], []
        ep_r = 0.
        print("initialized:", self.name)

        while True:
            # NOTE(review): ``start`` is reset on every step, so ``time_done``
            # below covers only the final step of the episode -- confirm
            # whether per-episode timing was intended.
            start = time.time()

            a = self.lnet.choose_action(state)
            self.action_queue.put(1 if a in attack else 0)
            r = self.game.make_action(actions[a], frame_repeat)

            if self.game.is_episode_finished():
                done = True
                # No successor frame is read on a terminal step.  Bind s_
                # explicitly so optimize() never sees an unbound name (first
                # step of the run) or a frame left over from the previous
                # episode (first step of a later episode).  On any later step
                # this is the same object the name already referred to.
                s_ = state
            else:
                done = False
                s_ = game_state(self.game)

            ep_r += r
            buffer_a.append(a)
            buffer_s.append(state)
            buffer_r.append(r)

            if done or total_step % UPDATE_GLOBAL_ITER == 0:
                # update network
                # sync
                optimize(opt, self.lnet, done, s_, buffer_s, buffer_a,
                         buffer_r, GAMMA)
                buffer_s, buffer_a, buffer_r = [], [], []

                if done:
                    self.inner_ep.value += 1
                    print("Inner Ep:", self.inner_ep.value)
                    end = time.time()
                    time_done = end - start
                    record(self.g_ep, self.g_ep_r, ep_r, self.res_queue,
                           self.time_queue, self.g_time, time_done,
                           self.name)
                    # Hand the local network's gradients to the matching
                    # global parameters before leaving the episode.
                    for lp, gp in zip(self.lnet.parameters(),
                                      self.gnet.parameters()):
                        gp._grad = lp.grad
                    break

            state = s_
            total_step += 1

    self.time_queue.put(None)
    self.res_queue.put(None)
    self.action_queue.put(None)
self.time_queue.put(None) self.res_queue.put(None) self.action_queue.put(None) if __name__ == '__main__': print("Starting A2C-Sync Agent for Vizdoom-DeadlyCorridor") time.sleep(3) timedelta_sum = datetime.now() timedelta_sum -= timedelta_sum fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True) if handleArguments().normalized_plot and not handleArguments().save_data: runs = 3 else: runs = 1 for i in range(runs): starttime = datetime.now() # load global network if handleArguments().load_model: model = Net(len(actions)) model = torch.load("./VIZDOOM/doom_save_model/a2c_sync_doom.pt") model.eval() else: model = Net(len(actions))