예제 #1
0
def Parameter_Server(Synchronizer, cluster, log_path, model_path, procs):
    """Run the parameter-server ("ps") task of the distributed training job.

    Builds the networks and a single agent on the "ps" server and initialises
    the variables. It then performs ``TRAIN_ITERS`` rounds. Each round resets
    the agent's old network and waits on ``Synchronizer`` twice: once before
    and once after the network update. After each round it writes the agent's
    summary, and it saves the model whenever the win rate ties or beats the
    best seen so far.

    Args:
        Synchronizer: object whose ``wait()`` is called twice per round
            (presumably a barrier shared with the worker processes --
            confirm against the caller).
        cluster: cluster specification handed to ``tf.train.Server`` and
            ``MiniNetwork``.
        log_path: directory for the ``tf.summary.FileWriter`` event files.
        model_path: forwarded to ``MiniNetwork`` as ``ppo_save_path``.
        procs: not used by this function; kept for interface compatibility.

    Returns:
        The highest win rate returned by ``agent.update_summary``.
    """
    config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
    )
    # Allocate GPU memory on demand rather than reserving it all up front
    # (presumably so several processes can share one device).
    config.gpu_options.allow_growth = True
    server = tf.train.Server(cluster,
                             job_name="ps",
                             task_index=0,
                             config=config)
    sess = tf.Session(target=server.target, config=config)
    summary_writer = tf.summary.FileWriter(log_path)
    try:
        Net = MiniNetwork(sess=sess,
                          summary_writer=summary_writer,
                          rl_training=FLAGS.training,
                          cluster=cluster,
                          index=0,
                          device=DEVICE[0 % len(DEVICE)],
                          ppo_load_path=FLAGS.restore_model_path,
                          ppo_save_path=model_path)
        # The second network is loaded from disk and not trained here.
        Sec_Net = SecondNetwork(sess=sess,
                                rl_training=False,
                                reuse=True,
                                cluster=None,
                                index=0,
                                load_model=True,
                                net_path_name=net_path_name)

        agent = mini_source_agent.MiniSourceAgent(
            index=-1,
            net=Net,
            sec_net=Sec_Net,
            restore_model=FLAGS.restore_model,
            rl_training=FLAGS.training)

        print("Parameter server: waiting for cluster connection...")
        # The first run against the distributed session is used as a
        # connectivity check; its result is discarded.
        sess.run(tf.report_uninitialized_variables())
        print("Parameter server: cluster ready!")

        print("Parameter server: initializing variables...")
        agent.init_network()
        print("Parameter server: variables initialized")

        update_counter = 0
        max_win_rate = 0.
        while update_counter < TRAIN_ITERS:
            agent.reset_old_network()

            # wait for update
            Synchronizer.wait()
            logging("Update Network!")
            # TODO count the time , compare cpu and gpu
            time.sleep(1)

            # update finish
            Synchronizer.wait()
            logging("Update Network finished!")

            steps, win_rate = agent.update_summary(update_counter)
            logging("Steps: %d, win rate: %f" % (steps, win_rate))

            update_counter += 1
            # ">=" means a tie with the best win rate also triggers a save.
            if win_rate >= max_win_rate:
                agent.save_model()
                max_win_rate = win_rate

        return max_win_rate
    finally:
        # Flush and close the event file so summaries from the final rounds
        # are not lost, and release the session even if training raised.
        summary_writer.close()
        sess.close()
예제 #2
0
def Worker(index, update_game_num, Synchronizer, cluster, model_path):
    """Run one "worker" task of the distributed training job.

    Builds the networks on this worker's server and creates ``THREAD_NUM``
    agents that share one ``Buffer``. It blocks until every variable has been
    initialised (done elsewhere; the ps task calls ``agent.init_network()``).
    Then it runs ``run_thread`` for each agent: ``THREAD_NUM - 1`` of them in
    daemon threads and the last one in the calling thread. It returns once all
    of them have finished.

    Args:
        index: task index of this worker in the "worker" job; also selects
            the device via ``DEVICE[index % len(DEVICE)]``.
        update_game_num: total number of games split across the
            ``THREAD_NUM`` threads. Each thread gets the ceiling of the share.
        Synchronizer: passed through to every ``run_thread`` call.
        cluster: cluster specification handed to ``tf.train.Server`` and
            ``MiniNetwork``.
        model_path: forwarded to ``MiniNetwork`` as ``ppo_save_path``.
    """
    config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
    )
    # Allocate GPU memory on demand rather than reserving it all up front.
    config.gpu_options.allow_growth = True
    worker = tf.train.Server(cluster,
                             job_name="worker",
                             task_index=index,
                             config=config)
    sess = tf.Session(target=worker.target, config=config)

    Net = MiniNetwork(sess=sess,
                      summary_writer=None,
                      rl_training=FLAGS.training,
                      cluster=cluster,
                      index=index,
                      device=DEVICE[index % len(DEVICE)],
                      ppo_load_path=FLAGS.restore_model_path,
                      ppo_save_path=model_path)
    # The second network is loaded from disk and not trained here.
    Sec_Net = SecondNetwork(sess=sess,
                            rl_training=False,
                            reuse=True,
                            cluster=None,
                            index=index,
                            load_model=True,
                            net_path_name=net_path_name)

    # All agents of this worker push into one shared buffer.
    global_buffer = Buffer()
    agents = []
    for i in range(THREAD_NUM):
        agent = mini_source_agent.MiniSourceAgent(
            index=i,
            global_buffer=global_buffer,
            net=Net,
            sec_net=Sec_Net,
            restore_model=FLAGS.restore_model,
            rl_training=FLAGS.training,
            strategy_agent=None)
        agents.append(agent)

    # Build the check op once. Each tf.report_uninitialized_variables() call
    # adds new ops to the graph, so calling it inside the polling loop below
    # would make the graph grow on every iteration.
    uninitialized_vars = tf.report_uninitialized_variables()

    print("Worker %d: waiting for cluster connection..." % index)
    # The first run is used as a connectivity check; its result is discarded.
    sess.run(uninitialized_vars)
    print("Worker %d: cluster ready!" % index)

    # Poll until another task has initialised every variable.
    while len(sess.run(uninitialized_vars)):
        print("Worker %d: waiting for variable initialization..." % index)
        time.sleep(1)
    print("Worker %d: variables initialized" % index)

    # Per-thread share of update_game_num, rounded up. True division is
    # required here: floor division (the previous "//") already truncates,
    # which made np.ceil a no-op. float() keeps this correct on Python 2.
    game_num = int(np.ceil(update_game_num / float(THREAD_NUM)))

    # Reset the module-level events before any agent thread starts
    # (semantics defined by run_thread -- presumably "no update pending" and
    # "rollouts allowed"; confirm there).
    UPDATE_EVENT.clear()
    ROLLING_EVENT.set()

    # Run threads: all agents but the last in daemon threads.
    threads = []
    for i in range(THREAD_NUM - 1):
        t = threading.Thread(target=run_thread,
                             args=(agents[i], game_num, Synchronizer,
                                   FLAGS.difficulty))
        threads.append(t)
        t.daemon = True
        t.start()
        # Stagger start-up between threads (reason not visible here;
        # presumably to avoid launching every game environment at once).
        time.sleep(3)

    # The last agent runs in the calling thread.
    run_thread(agents[-1], game_num, Synchronizer, FLAGS.difficulty)

    for t in threads:
        t.join()