def Parameter_Server(Synchronizer, cluster, log_path, model_path, procs):
    config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
    )
    config.gpu_options.allow_growth = True
    server = tf.train.Server(cluster, job_name="ps", task_index=0, config=config)
    sess = tf.Session(target=server.target, config=config)
    summary_writer = tf.summary.FileWriter(log_path)

    Net = MiniNetwork(sess=sess, summary_writer=summary_writer, rl_training=FLAGS.training,
                      cluster=cluster, index=0, device=DEVICE[0 % len(DEVICE)],
                      ppo_load_path=FLAGS.restore_model_path, ppo_save_path=model_path)
    Sec_Net = SecondNetwork(sess=sess, rl_training=False, reuse=True, cluster=None, index=0,
                            load_model=True, net_path_name=net_path_name)
    agent = mini_source_agent.MiniSourceAgent(index=-1, net=Net, sec_net=Sec_Net,
                                              restore_model=FLAGS.restore_model,
                                              rl_training=FLAGS.training)

    print("Parameter server: waiting for cluster connection...")
    sess.run(tf.report_uninitialized_variables())
    print("Parameter server: cluster ready!")

    print("Parameter server: initializing variables...")
    agent.init_network()
    print("Parameter server: variables initialized")

    update_counter = 0
    max_win_rate = 0.
    while update_counter < TRAIN_ITERS:
        agent.reset_old_network()

        # First barrier: wait until every worker has finished collecting rollouts
        # and is ready for the network update.
        Synchronizer.wait()
        logging("Update Network!")
        # TODO: measure the update time, compare CPU and GPU.
        time.sleep(1)

        # Second barrier: wait until every worker has finished its update.
        Synchronizer.wait()
        logging("Update Network finished!")

        steps, win_rate = agent.update_summary(update_counter)
        logging("Steps: %d, win rate: %f" % (steps, win_rate))

        update_counter += 1
        if win_rate >= max_win_rate:
            agent.save_model()
            max_win_rate = win_rate

    return max_win_rate
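
# Barrier-protocol sketch (illustrative only, not part of the training code).
# It assumes `Synchronizer` behaves like multiprocessing.Barrier: each training
# iteration uses two wait() calls, the first releasing once all parties are
# ready for the update, the second releasing once the update has finished so
# workers can resume rollouts against the refreshed weights.  The function and
# thread names below are hypothetical.
def _barrier_protocol_sketch():
    import threading

    barrier = threading.Barrier(parties=2)  # 1 "ps" + 1 "worker" for the demo

    def fake_worker():
        barrier.wait()  # phase 1: rollouts collected, ready for the update
        barrier.wait()  # phase 2: update finished, safe to continue

    t = threading.Thread(target=fake_worker)
    t.start()
    barrier.wait()   # ps side: "Update Network!"
    barrier.wait()   # ps side: "Update Network finished!"
    t.join()
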
def Worker(index, update_game_num, Synchronizer, cluster, model_path):
    config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
    )
    config.gpu_options.allow_growth = True
    worker = tf.train.Server(cluster, job_name="worker", task_index=index, config=config)
    sess = tf.Session(target=worker.target, config=config)

    Net = MiniNetwork(sess=sess, summary_writer=None, rl_training=FLAGS.training,
                      cluster=cluster, index=index, device=DEVICE[index % len(DEVICE)],
                      ppo_load_path=FLAGS.restore_model_path, ppo_save_path=model_path)
    Sec_Net = SecondNetwork(sess=sess, rl_training=False, reuse=True, cluster=None, index=index,
                            load_model=True, net_path_name=net_path_name)

    global_buffer = Buffer()
    agents = []
    for i in range(THREAD_NUM):
        agent = mini_source_agent.MiniSourceAgent(index=i, global_buffer=global_buffer,
                                                  net=Net, sec_net=Sec_Net,
                                                  restore_model=FLAGS.restore_model,
                                                  rl_training=FLAGS.training,
                                                  strategy_agent=None)
        agents.append(agent)

    print("Worker %d: waiting for cluster connection..." % index)
    sess.run(tf.report_uninitialized_variables())
    print("Worker %d: cluster ready!" % index)

    # Block until the parameter server has initialized all shared variables.
    while len(sess.run(tf.report_uninitialized_variables())):
        print("Worker %d: waiting for variable initialization..." % index)
        time.sleep(1)
    print("Worker %d: variables initialized" % index)

    # Number of games each rollout thread is responsible for (rounded up).
    game_num = int(np.ceil(update_game_num / THREAD_NUM))

    UPDATE_EVENT.clear()
    ROLLING_EVENT.set()

    # Run all but the last agent in background threads; the last agent runs in
    # the current thread so Worker() blocks until its rollouts are done.
    threads = []
    for i in range(THREAD_NUM - 1):
        t = threading.Thread(target=run_thread,
                             args=(agents[i], game_num, Synchronizer, FLAGS.difficulty))
        threads.append(t)
        t.daemon = True
        t.start()
        time.sleep(3)

    run_thread(agents[-1], game_num, Synchronizer, FLAGS.difficulty)

    for t in threads:
        t.join()
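
# Launch sketch (illustrative only).  The ports, worker count, game count and
# paths below are assumptions, not values from the original code.  It shows how
# Parameter_Server and Worker are typically wired together: every process shares
# one tf.train.ClusterSpec, and a multiprocessing.Barrier with (workers + 1)
# parties serves as the Synchronizer used by both sides.
def _launch_cluster_sketch(worker_num=2, update_game_num=100,
                           log_path="./log/", model_path="./model/"):
    import multiprocessing as mp

    cluster = tf.train.ClusterSpec({
        "ps": ["localhost:2220"],
        "worker": ["localhost:%d" % (2221 + i) for i in range(worker_num)],
    })
    # One barrier party per worker process plus one for the parameter server.
    synchronizer = mp.Barrier(worker_num + 1)

    procs = []
    for i in range(worker_num):
        p = mp.Process(target=Worker,
                       args=(i, update_game_num, synchronizer, cluster, model_path))
        p.daemon = True
        p.start()
        procs.append(p)

    # The parameter server runs in the current process and returns the best
    # win rate it observed over training.
    return Parameter_Server(synchronizer, cluster, log_path, model_path, procs)
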