def run_thread(agent, game_num, Synchronizer, difficulty):
    """Run one rollout worker against the built-in bot at one difficulty.

    The worker plays ``game_num * TRAIN_ITERS`` episodes in a single
    ``sc2_env.SC2Env``. When ``FLAGS.training`` is set, after every
    ``game_num`` episodes the rolling is paused and the updating worker
    (``agent.index == 0``, or any worker when ``THREAD_NUM <= 1``) calls
    ``agent.update_network(Result_List)`` between two
    ``Synchronizer.wait()`` calls, then releases the other workers through
    ``ROLLING_EVENT``.

    Args:
        agent: Agent object. This function calls ``set_env``,
            ``play_right_add``, ``update_network``, ``global_buffer.reset``
            and ``reset``, and reads ``result['reward']`` and ``index``.
        game_num: Number of episodes played by this worker between updates.
        Synchronizer: Object with a ``wait()`` method, called before and after
            the network update. It is presumably a cross-process barrier;
            confirm at the call site.
        difficulty: Bot difficulty. It is passed to ``SC2Env``, stored on
            ``C.difficulty``, and formatted with ``int()`` in the log line.
    """
    global UPDATE_EVENT, ROLLING_EVENT, Counter, Waiting_Counter, Update_Counter, Result_List

    num = 0  # episodes since the last network update
    all_num = 0  # episodes played by this worker in total
    proc_name = mp.current_process().name

    C._FPS = 22.4 / FLAGS.step_mul  # 5.6
    step_mul = FLAGS.step_mul  # 4
    C.difficulty = difficulty
    with sc2_env.SC2Env(
            map_name=FLAGS.map,
            agent_race=FLAGS.agent_race,
            bot_race=FLAGS.bot_race,
            difficulty=difficulty,
            step_mul=step_mul,
            score_index=-1,
            game_steps_per_episode=MAX_AGENT_STEPS,
            screen_size_px=(FLAGS.screen_resolution, FLAGS.screen_resolution),
            minimap_size_px=(FLAGS.minimap_resolution, FLAGS.minimap_resolution),
            visualize=False,
            game_version=FLAGS.game_version) as env:
        # env = available_actions_printer.AvailableActionsPrinter(env)
        agent.set_env(env)

        total_games = game_num * TRAIN_ITERS
        while all_num < total_games:
            agent.play_right_add(verbose=FLAGS.debug_mode)
            # Count every episode, not only training ones. Previously this
            # counter advanced only under FLAGS.training, so with training
            # disabled the loop never terminated.
            all_num += 1

            if FLAGS.training:
                # check if the num of episodes is enough to update
                num += 1
                reward = agent.result['reward']
                Counter += 1
                Result_List.append(reward)
                logging(
                    "(diff: %d) %d epoch: %s get %d/%d episodes! return: %d!" %
                    (int(difficulty), Update_Counter, proc_name, len(Result_List),
                     game_num * THREAD_NUM, reward))

                # time for update
                if num == game_num:
                    num = 0
                    ROLLING_EVENT.clear()
                    # worker stops rolling, wait for update
                    if agent.index != 0 and THREAD_NUM > 1:
                        # NOTE(review): `+=` on a module global is not atomic. If
                        # workers are threads sharing this global, a lost increment
                        # would leave UPDATE_EVENT unset. Consider guarding it with
                        # a lock; confirm the threading model.
                        Waiting_Counter += 1
                        if Waiting_Counter == THREAD_NUM - 1:  # wait for all the workers stop
                            UPDATE_EVENT.set()
                        ROLLING_EVENT.wait()

                    # update!
                    else:
                        if THREAD_NUM > 1:
                            # Block until every other worker has paused.
                            UPDATE_EVENT.wait()

                        Synchronizer.wait()  # wait for other processes to update

                        agent.update_network(Result_List)
                        Result_List.clear()
                        agent.global_buffer.reset()

                        Synchronizer.wait()

                        Update_Counter += 1

                        # finish update: reset the handshake and resume rolling
                        UPDATE_EVENT.clear()
                        Waiting_Counter = 0
                        ROLLING_EVENT.set()

            if FLAGS.save_replay:
                env.save_replay(FLAGS.replay_dir)

            agent.reset()
def _write_difficulty_report(difficulty):
    """Append the win / tie / loss summary for ``difficulty`` to the result log.

    Counts are read from ``RESULT_ARRAY``. The row is the position of
    ``difficulty`` in ``Difficulty_list``; the column is the position of
    reward 1, 0 or -1 in ``Reward_list``. Each count is written together with
    its ratio to ``GAME_NUM``. The log file is
    ``./result/<agent_race>v<bot_race>_<model dir>_<map>.txt``, and the
    directory is created if it is missing.
    """
    import os

    row = Difficulty_list.index(difficulty)
    win = RESULT_ARRAY[row, Reward_list.index(1)]
    fair = RESULT_ARRAY[row, Reward_list.index(0)]
    lose = RESULT_ARRAY[row, Reward_list.index(-1)]

    log_path = "./result/" + FLAGS.agent_race + 'v' + FLAGS.bot_race + '_' + \
        FLAGS.restore_model_path.split('/')[-2] + '_' + FLAGS.map + '.txt'
    # Without this, a missing ./result/ directory would crash the report
    # only after every evaluation game had already been played.
    os.makedirs(os.path.dirname(log_path), exist_ok=True)

    # `with` guarantees the file is closed even if a write raises.
    with open(log_path, "a") as log_file:
        log_file.write('difficulty: %s, game_num: %d\n' % (difficulty, GAME_NUM))
        log_file.write('win: %d, %.2f\n' % (int(win), win / GAME_NUM))
        log_file.write('fair: %d, %.2f\n' % (int(fair), fair / GAME_NUM))
        log_file.write('loss: %d, %.2f\n\n' % (int(lose), lose / GAME_NUM))


def run_thread(agent, Synchronizer):
    """Evaluate ``agent`` against the built-in bot at every difficulty.

    For each entry of ``Difficulty_list``, the worker opens an
    ``sc2_env.SC2Env`` and plays ``PER_GAME_NUM`` games. Each outcome is
    tallied into ``RESULT_ARRAY`` under ``LOCK``. All workers then meet at a
    rendezvous:

    1. Each worker clears ``ROLLING_EVENT`` and increments ``WAITING_COUNTER``.
    2. The worker that brings the count to ``PARALLEL`` sets ``UPDATE_EVENT``.
    3. The worker with ``agent.index == 0`` waits for ``UPDATE_EVENT``, writes
       the per-difficulty report, resets the shared counters, and sets
       ``ROLLING_EVENT``.
    4. Every worker waits on ``ROLLING_EVENT``; worker 0 then also calls
       ``Synchronizer.wait()``.

    Args:
        agent: Agent object. This function calls ``set_env``, ``play`` and
            ``reset``, and reads ``result['reward']`` and ``index``.
        Synchronizer: Object with a ``wait()`` method, called by worker 0
            after each difficulty. It is presumably a cross-process barrier;
            confirm at the call site.
    """
    global COUNTER, WAITING_COUNTER

    C._FPS = 2.8
    step_mul = FLAGS.step_mul

    for difficulty in Difficulty_list:
        row = Difficulty_list.index(difficulty)
        with sc2_env.SC2Env(
                map_name=FLAGS.map,
                agent_race=FLAGS.agent_race,
                bot_race=FLAGS.bot_race,
                difficulty=difficulty,
                step_mul=step_mul,
                score_index=-1,
                screen_size_px=(FLAGS.screen_resolution, FLAGS.screen_resolution),
                minimap_size_px=(FLAGS.minimap_resolution, FLAGS.minimap_resolution),
                visualize=False,
                game_steps_per_episode=900 * 22.4,
                game_version=FLAGS.game_version) as env:
            # Only for a single player!
            agent.set_env(env)
            # Difficulty "A" is recorded as 10 on C; other values pass through.
            C.difficulty = 10 if difficulty == "A" else difficulty

            for _ in range(PER_GAME_NUM):
                agent.play()
                reward = agent.result['reward']
                with LOCK:
                    RESULT_ARRAY[row, Reward_list.index(reward)] += 1
                    COUNTER += 1
                    print("difficulty %s: finished %d games!" % (difficulty, COUNTER))
                agent.reset()
                time.sleep(2)

            # This worker is done with the current difficulty: stop rolling
            # and register at the rendezvous.
            if ROLLING_EVENT.is_set():
                ROLLING_EVENT.clear()
            with LOCK:
                # This read-modify-write on a shared global must be atomic. A
                # lost increment would keep the count below PARALLEL, so
                # UPDATE_EVENT would never be set and worker 0 would block forever.
                WAITING_COUNTER += 1
                if WAITING_COUNTER == PARALLEL:
                    UPDATE_EVENT.set()

            if agent.index == 0:
                # Block until every worker has finished this difficulty.
                UPDATE_EVENT.wait()
                _write_difficulty_report(difficulty)

                UPDATE_EVENT.clear()
                ROLLING_EVENT.set()
                WAITING_COUNTER = 0
                COUNTER = 0

            ROLLING_EVENT.wait()
            if agent.index == 0:
                Synchronizer.wait()