# Example 1
# 0
def train():
    """Run distributed PPO training (one Horovod process per GPU).

    Initializes Horovod and the Ray cluster, builds a prefetching input
    pipeline fed by remote rollout workers, constructs the learner graph,
    and runs the synchronous training loop until
    FLAGS.total_environment_frames environment frames have been consumed.
    """
    # BUGFIX: hvd.init() must run before any hvd.local_rank() call --
    # Horovod raises if the rank is queried before initialization, and the
    # original code queried it inside the init_cluster_ray(...) argument.
    hvd.init()
    init_cluster_ray(log_to_driver=(hvd.local_rank() == 0))
    base_dir, ckpt_dir, summary_dir = init_dir_and_log()

    # Forward all command-line flags (plus run directories) to the workers.
    kwargs = FLAGS.flag_values_dict()
    kwargs["BASE_DIR"] = base_dir
    kwargs["ckpt_dir"] = ckpt_dir

    # Fetch one segment from a throwaway single-server collector so the
    # queue below can be built with the right keys / dtypes / shapes.
    logging.info('get one seg from Evaluator for dtype and shapes')
    ps = AsyncPS.remote()
    small_data_collector = RolloutCollector(
        server_nums=1,
        ps=ps,
        policy_evaluator_build_func=build_policy_evaluator,
        **kwargs)
    cache_struct_path = '/tmp/ppo%s.pkl' % FLAGS.dir

    # Only the head process (local rank 0) materializes the cached
    # structure file; the other ranks read the cache it produced.
    structure = fetch_one_structure(small_data_collector,
                                    cache_struct_path=cache_struct_path,
                                    is_head=(hvd.local_rank() == 0))
    del small_data_collector

    # Init the data-prefetch thread and prepare the input pipe.
    keys = list(structure.keys())
    dtypes = [structure[k].dtype for k in keys]
    shapes = [structure[k].shape for k in keys]
    segBuffer = tf.queue.FIFOQueue(capacity=FLAGS.qsize * FLAGS.batch_size,
                                   dtypes=dtypes,
                                   shapes=shapes,
                                   names=keys,
                                   shared_name="buffer")

    # Scale the number of rollout servers by the CPU budget per actor and
    # the number of server GPUs.
    server_nums = FLAGS.nof_evaluator
    server_nums_refine = server_nums * 2 // FLAGS.cpu_per_actor
    nof_server_gpus = FLAGS.nof_server_gpus
    server_nums_refine = server_nums_refine // nof_server_gpus
    data_collector = RolloutCollector(
        server_nums=server_nums_refine,
        ps=ps,
        policy_evaluator_build_func=build_policy_evaluator,
        **kwargs)

    config = tf.ConfigProto(
        allow_soft_placement=True,
        gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=1))
    config.gpu_options.allow_growth = True
    # Pin each Horovod process to its own GPU.
    config.gpu_options.visible_device_list = str(hvd.local_rank())

    sess = tf.Session(config=config)
    # Daemon thread that pulls segments from the collector and enqueues
    # them into segBuffer.
    reader = QueueReader(sess=sess,
                         global_queue=segBuffer,
                         data_collector=data_collector,
                         keys=keys,
                         dtypes=dtypes,
                         shapes=shapes)
    reader.daemon = True
    reader.start()

    dequeued = segBuffer.dequeue_many(FLAGS.batch_size)
    # Split each sequence into a burn-in prefix (pre) and a training window
    # (post); the recurrent state tensor belongs to pre only.
    prephs, postphs = dict(), dict()
    for k, v in dequeued.items():
        if k == "state_in":
            prephs[k] = v
        else:
            prephs[k], postphs[k] = tf.split(v, [FLAGS.burn_in, FLAGS.seqlen],
                                             axis=1)
    prekeys = list(prephs.keys())
    postkeys = list(postphs.keys())

    # Count environment frames and total train steps.
    num_frames = tf.get_variable('num_environment_frames',
                                 initializer=tf.zeros_initializer(),
                                 shape=[],
                                 dtype=tf.int32,
                                 trainable=False)
    tf.summary.scalar("frames", num_frames)
    global_step = tf.train.get_or_create_global_step()

    dur_time_tensor = tf.placeholder(dtype=tf.float32)
    tf.summary.scalar('time_per_step', dur_time_tensor)

    # Set stage_op (GPU double-buffering of the next batch) and build the
    # learner graph.
    with tf.device("/gpu"):
        if FLAGS.use_stage:
            area = tf.contrib.staging.StagingArea(
                [prephs[key].dtype for key in prekeys] +
                [postphs[key].dtype
                 for key in postkeys], [prephs[key].shape for key in prekeys] +
                [postphs[key].shape for key in postkeys])
            stage_op = area.put([prephs[key] for key in prekeys] +
                                [postphs[key] for key in postkeys])
            from_stage = area.get()
            predatas = {key: from_stage[i] for i, key in enumerate(prekeys)}
            postdatas = {
                key: from_stage[i + len(prekeys)]
                for i, key in enumerate(postkeys)
            }
        else:
            stage_op = []
            predatas, postdatas = prephs, postphs

        act_space = FLAGS.act_space
        num_frames_and_train, global_step_and_train = build_learner(
            pre=predatas,
            post=postdatas,
            act_space=act_space,
            num_frames=num_frames)

    # Add summaries; only rank 0 writes the events file.
    summary_ops = tf.summary.merge_all()
    if hvd.local_rank() == 0:
        summary_writer = tf.summary.FileWriter(summary_dir, sess.graph)

    # Initialize variables or restore from the latest checkpoint, then
    # publish the initial weights so rollout workers can start acting.
    saver = tf.train.Saver(max_to_keep=100, keep_checkpoint_every_n_hours=6)
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        sess.run(tf.global_variables_initializer())
    ws = Model.get_ws(sess)
    logging.info('pushing weight to ps')
    ray.get(ps.push.remote(ws))

    # Make every rank start from rank 0's weights.
    sess.run(hvd.broadcast_global_variables(0))
    if hvd.local_rank() == 0:
        saver.save(sess,
                   os.path.join(ckpt_dir, "PPOcGAE"),
                   global_step=global_step)

    # Train loop.
    total_frames = 0
    # Prime the staging area so the first train step has data.
    sess.run(stage_op)

    dur_time = 0
    while total_frames < FLAGS.total_environment_frames:
        start = time.time()

        # NOTE(review): dur_time is only refreshed on rank 0 below, so the
        # other ranks always feed 0 here -- confirm this is intended.
        total_frames, gs, summary, _ = sess.run([
            num_frames_and_train, global_step_and_train, summary_ops, stage_op
        ],
                                                feed_dict={
                                                    dur_time_tensor: dur_time
                                                })

        if hvd.local_rank() == 0:
            # Log and summarize every step (the original `gs % 1 == 0`
            # guard was always true).
            summary_writer.add_summary(summary, global_step=gs)
            dur_time = time.time() - start
            msg = "Global Step %d, Total Frames %d,  Time Consume %.2f" % (
                gs, total_frames, dur_time)
            logging.info(msg)
            if gs % 25 == 0:
                # Refresh the parameter server so rollout workers pick up
                # the latest policy weights.
                ws = Model.get_ws(sess)
                logging.info('pushing weight to ps')
                ray.get(ps.push.remote(ws))
            if gs % 1000 == 0:
                saver.save(sess,
                           os.path.join(ckpt_dir, "CKPT"),
                           global_step=global_step)

    if hvd.local_rank() == 0:
        saver.save(sess,
                   os.path.join(ckpt_dir, "CKPT"),
                   global_step=global_step)
# Example 2
# 0
def train():
    """Train with prioritized experience replay plus a demonstration buffer.

    Sets up Ray, builds a rollout input pipeline that feeds a prioritized
    replay buffer and a demo buffer, constructs the learner graph (with a
    target network and importance-sampling weights), and runs the train
    loop until FLAGS.total_environment_frames frames are consumed.
    """
    init_cluster_ray()
    base_dir, ckpt_dir, summary_dir = init_dir_and_log()

    # Forward all command-line flags (plus run directories) to the workers.
    kwargs = FLAGS.flag_values_dict()
    kwargs["BASE_DIR"] = base_dir
    kwargs["ckpt_dir"] = ckpt_dir
    act_space = int(FLAGS.act_space)
    kwargs["act_space"] = act_space
    """
    get one seg from rollout worker for dtype and shapes

    :param kwargs rollout worker config
    """
    logging.info('get one seg from Evaluator for dtype and shapes')
    ps = AsyncPS.remote()
    # One-server collector used only to discover the segment structure.
    small_data_collector = RolloutCollector(
        server_nums=1,
        ps=ps,
        policy_evaluator_build_func=build_policy_evaluator,
        **kwargs)
    cache_struct_path = '/tmp/%s.pkl' % FLAGS.dir
    structure = fetch_one_structure(small_data_collector,
                                    cache_struct_path=cache_struct_path,
                                    is_head=True)
    del small_data_collector
    """
        init data prefetch thread, prepare_input_pipe
    """
    keys = list(structure.keys())
    dtypes = [structure[k].dtype for k in keys]
    shapes = [structure[k].shape for k in keys]
    # Shared FIFO queue the background reader thread fills with segments.
    segBuffer = tf.queue.FIFOQueue(capacity=FLAGS.qsize * FLAGS.batch_size,
                                   dtypes=dtypes,
                                   shapes=shapes,
                                   names=keys,
                                   shared_name="buffer")

    # Spread the evaluator servers across the available server GPUs.
    server_nums = FLAGS.nof_evaluator
    nof_server_gpus = FLAGS.nof_server_gpus
    server_nums_refine = server_nums // nof_server_gpus
    data_collector = RolloutCollector(
        server_nums=server_nums_refine,
        ps=ps,
        policy_evaluator_build_func=build_policy_evaluator,
        **kwargs)

    config = tf.ConfigProto(
        allow_soft_placement=True,
        gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=1))
    config.gpu_options.allow_growth = True

    sess = tf.Session(config=config)

    # Daemon thread that pulls segments from the collector into segBuffer.
    reader = QueueReader(sess=sess,
                         global_queue=segBuffer,
                         data_collector=data_collector,
                         keys=keys,
                         dtypes=dtypes,
                         shapes=shapes)
    reader.daemon = True
    reader.start()

    dequeued = segBuffer.dequeue_many(FLAGS.batch_size)

    # //////////////////////
    # Demo buffer and prioritized replay buffer. 0.9 is presumably the
    # priority exponent (alpha) and 10000 the capacity -- confirm against
    # the buffer implementations.
    demo_buffer = build_demo_buffer(keys, 0.9)
    # //////////////////////
    replay_buffer = PrioritizedReplayBuffer(10000, keys, 0.9)
    # Per-sample importance-sampling weights fed at train time.
    weights = tf.placeholder(dtype=tf.float32, shape=[None])

    # One placeholder per segment key; leading batch dimension is dynamic
    # so replay and demo samples can be concatenated freely.
    phs = {
        key: tf.placeholder(dtype=dtype, shape=[None] + list(shape))
        for key, dtype, shape in zip(keys, dtypes, shapes)
    }

    # Split each sequence into burn-in (pre) and training window plus
    # n-step bootstrap (post); the recurrent state goes to pre only.
    prephs, postphs = dict(), dict()
    for k, v in phs.items():
        if k == "state_in":
            prephs[k] = v
        else:
            prephs[k], postphs[k] = tf.split(
                v, [FLAGS.burn_in, FLAGS.seqlen + FLAGS.n_step], axis=1)
    prekeys = list(prephs.keys())
    postkeys = list(postphs.keys())
    """
        count frame and total steps
    """
    num_frames = tf.get_variable('num_environment_frames',
                                 initializer=tf.zeros_initializer(),
                                 shape=[],
                                 dtype=tf.int32,
                                 trainable=False)
    tf.summary.scalar("frames", num_frames)
    global_step = tf.train.get_or_create_global_step()

    dur_time_tensor = tf.placeholder(dtype=tf.float32)
    tf.summary.scalar('time_per_step', dur_time_tensor)
    """
        set stage_op and build learner
    """
    with tf.device("/gpu"):
        if FLAGS.use_stage:
            # StagingArea double-buffers the next batch on the GPU.
            area = tf.contrib.staging.StagingArea(
                [prephs[key].dtype for key in prekeys] +
                [postphs[key].dtype
                 for key in postkeys], [prephs[key].shape for key in prekeys] +
                [postphs[key].shape for key in postkeys])
            stage_op = area.put([prephs[key] for key in prekeys] +
                                [postphs[key] for key in postkeys])
            from_stage = area.get()
            predatas = {key: from_stage[i] for i, key in enumerate(prekeys)}
            postdatas = {
                key: from_stage[i + len(prekeys)]
                for i, key in enumerate(postkeys)
            }
        else:
            stage_op = []
            predatas, postdatas = prephs, postphs

        # target_op syncs the target network; priority returns per-sample
        # priorities; beta is the importance-sampling exponent tensor.
        num_frames_and_train, global_step_and_train, target_op, priority, beta = build_learner(
            pre=predatas,
            post=postdatas,
            ws=weights,
            act_space=act_space,
            num_frames=num_frames)
    """
        add summary
    """
    summary_ops = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(summary_dir, sess.graph)
    """
        initialize and save ckpt
    """
    saver = tf.train.Saver(max_to_keep=100, keep_checkpoint_every_n_hours=6)
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        sess.run(tf.global_variables_initializer())
    # Publish initial weights so rollout workers can start acting.
    ws = Model.get_ws(sess)
    logging.info('pushing weight to ps')
    ray.get(ps.push.remote(ws))

    saver.save(sess, os.path.join(ckpt_dir, "CKPT"), global_step=global_step)
    """
        step
    """
    total_frames = 0
    # Sync the target network, then seed the replay buffer with one batch.
    sess.run(target_op)
    dequeued_datas, sample_beta = sess.run([dequeued, beta])
    replay_buffer.add_batch(dequeued_datas, FLAGS.batch_size)

    dur_time = 0
    while total_frames < FLAGS.total_environment_frames:
        start = time.time()

        # Randomly split the batch between replay and demo samples; the
        # binomial draw plus the +1 / -2 offsets guarantee at least one
        # sample comes from each buffer.
        batch_size = np.random.binomial(FLAGS.batch_size - 2, 0.99) + 1
        demo_batch_size = FLAGS.batch_size - batch_size

        datas, is_weights, idxes = replay_buffer.sample(
            batch_size, sample_beta)
        demo_datas, demo_is_weights, demo_idxes = demo_buffer.sample(
            demo_batch_size, sample_beta)

        # Concatenate replay and demo samples into one feed batch; replay
        # samples occupy the first batch_size rows.
        fd = {
            phs[k]: np.concatenate([datas[k], demo_datas[k]], axis=0)
            for k in keys
        }
        fd[weights] = np.concatenate([is_weights, demo_is_weights], axis=0)
        fd[dur_time_tensor] = dur_time

        total_frames, gs, summary, _, p, sample_beta = sess.run([
            num_frames_and_train, global_step_and_train, summary_ops, stage_op,
            priority, beta
        ],
                                                                feed_dict=fd)

        # New priorities: first batch_size entries belong to the replay
        # samples, the remainder to the demo samples (matches fd layout).
        replay_buffer.update_priorities(idxes, p[:batch_size])
        demo_buffer.update_priorities(demo_idxes, p[batch_size:])

        # Refill the replay buffer from the rollout queue every 4 steps.
        if gs % 4 == 0:
            dequeued_datas = sess.run(dequeued)
            replay_buffer.add_batch(dequeued_datas, FLAGS.batch_size)

        # Periodically sync the target network.
        if gs % FLAGS.target_update == 0:
            sess.run(target_op)

        if gs % 25 == 0:
            ws = Model.get_ws(sess)
            # NOTE(review): hard-coded dump path; consider deriving it from
            # base_dir instead.
            with open("/opt/tiger/test_ppo/ws.pkl", "wb") as f:
                pickle.dump(ws, f)
            logging.info('pushing weight to ps')
            # Best-effort push: tolerate transient Ray failures.
            try:
                ray.get(ps.push.remote(ws))
            except ray.exceptions.UnreconstructableError as e:
                logging.info(str(e))
            except ray.exceptions.RayError as e:
                logging.info(str(e))

        if gs % 1000 == 0:
            saver.save(sess,
                       os.path.join(ckpt_dir, "CKPT"),
                       global_step=global_step)

        # `gs % 1 == 0` is always true: summarize and log every step.
        if gs % 1 == 0:
            summary_writer.add_summary(summary, global_step=gs)
            dur_time = time.time() - start
            msg = "Global Step %d, Total Frames %d,  Time Consume %.2f" % (
                gs, total_frames, dur_time)
            logging.info(msg)

    saver.save(sess, os.path.join(ckpt_dir, "CKPT"), global_step=global_step)
# Example 3
# 0
def train():
    """Train directly from a shuffled rollout queue (no replay buffer).

    Resolves the game and action space via gym, builds a RandomShuffleQueue
    input pipeline and a learner graph with a target network, trains until
    FLAGS.total_environment_frames frames are consumed, then shuts down the
    worker machines over ssh.
    """
    init_cluster_ray()
    base_dir, ckpt_dir, summary_dir = init_dir_and_log()

    # Forward all command-line flags (plus run directories) to the workers.
    kwargs = FLAGS.flag_values_dict()
    kwargs["BASE_DIR"] = base_dir
    kwargs["ckpt_dir"] = ckpt_dir
    # Resolve the configured game through get_games(), then create the env
    # only long enough to read its action-space size.
    games = get_games()
    kwargs["game"] = games[kwargs["game"]]
    env = gym.make(kwargs["game"])
    act_space = env.action_space.n
    kwargs["act_space"] = act_space
    del env
    """
    get one seg from rollout worker for dtype and shapes

    :param kwargs rollout worker config
    """
    logging.info('get one seg from Evaluator for dtype and shapes')
    ps = AsyncPS.remote()
    # One-server collector used only to discover segment dtypes/shapes.
    small_data_collector = RolloutCollector(
        server_nums=1,
        ps=ps,
        policy_evaluator_build_func=build_policy_evaluator,
        **kwargs)
    cache_struct_path = '/tmp/%s_%s.pkl' % (FLAGS.dir, kwargs["game"])
    structure = fetch_one_structure(small_data_collector,
                                    cache_struct_path=cache_struct_path,
                                    is_head=True)
    del small_data_collector
    """
        init data prefetch thread, prepare_input_pipe
    """
    keys = list(structure.keys())
    dtypes = [structure[k].dtype for k in keys]
    shapes = [structure[k].shape for k in keys]
    # RandomShuffleQueue decorrelates consecutive segments before batching;
    # min_after_dequeue keeps it half full for adequate mixing.
    segBuffer = tf.queue.RandomShuffleQueue(
        capacity=FLAGS.qsize * FLAGS.batch_size,
        min_after_dequeue=FLAGS.qsize * FLAGS.batch_size // 2,
        dtypes=dtypes,
        shapes=shapes,
        names=keys,
        shared_name="buffer")
    # Spread the evaluator servers across the available server GPUs.
    server_nums = FLAGS.nof_evaluator
    nof_server_gpus = FLAGS.nof_server_gpus
    server_nums_refine = server_nums // nof_server_gpus
    data_collector = RolloutCollector(
        server_nums=server_nums_refine,
        ps=ps,
        policy_evaluator_build_func=build_policy_evaluator,
        **kwargs)

    config = tf.ConfigProto(
        allow_soft_placement=True,
        gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=1))
    config.gpu_options.allow_growth = True

    sess = tf.Session(config=config)

    # Daemon thread that pulls segments from the collector into segBuffer.
    reader = QueueReader(sess=sess,
                         global_queue=segBuffer,
                         data_collector=data_collector,
                         keys=keys,
                         dtypes=dtypes,
                         shapes=shapes)
    reader.daemon = True
    reader.start()

    dequeued = segBuffer.dequeue_many(FLAGS.batch_size)

    # //////////////////////
    # Train straight from the queue (no replay buffer); every sample gets
    # a uniform weight of 1.
    from_where = dequeued
    batch_weights = tf.ones(FLAGS.batch_size)

    # //////////////////////

    # Split each sequence into burn-in (pre) and training window plus
    # n-step bootstrap (post); the recurrent state goes to pre only.
    prephs, postphs = dict(), dict()
    for k, v in from_where.items():
        if k == "state_in":
            prephs[k] = v
        else:
            prephs[k], postphs[k] = tf.split(
                v, [FLAGS.burn_in, FLAGS.seqlen + FLAGS.n_step], axis=1)
    prekeys = list(prephs.keys())
    postkeys = list(postphs.keys())
    """
        count frame and total steps
    """
    num_frames = tf.get_variable('num_environment_frames',
                                 initializer=tf.zeros_initializer(),
                                 shape=[],
                                 dtype=tf.int32,
                                 trainable=False)
    tf.summary.scalar("frames", num_frames)
    global_step = tf.train.get_or_create_global_step()

    dur_time_tensor = tf.placeholder(dtype=tf.float32)
    tf.summary.scalar('time_per_step', dur_time_tensor)
    """
        set stage_op and build learner
    """
    with tf.device("/gpu"):
        if FLAGS.use_stage:
            # StagingArea double-buffers the next batch on the GPU.
            area = tf.contrib.staging.StagingArea(
                [prephs[key].dtype for key in prekeys] +
                [postphs[key].dtype
                 for key in postkeys], [prephs[key].shape for key in prekeys] +
                [postphs[key].shape for key in postkeys])
            stage_op = area.put([prephs[key] for key in prekeys] +
                                [postphs[key] for key in postkeys])
            from_stage = area.get()
            predatas = {key: from_stage[i] for i, key in enumerate(prekeys)}
            postdatas = {
                key: from_stage[i + len(prekeys)]
                for i, key in enumerate(postkeys)
            }
        else:
            stage_op = []
            predatas, postdatas = prephs, postphs

        # init_target_op syncs the target network; priority/beta are built
        # but priorities are unused here (no replay buffer to update).
        num_frames_and_train, global_step_and_train, init_target_op, priority, beta = build_learner(
            pre=predatas,
            post=postdatas,
            act_space=act_space,
            num_frames=num_frames,
            batch_weights=batch_weights)
    """
        add summary
    """
    summary_ops = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(summary_dir, sess.graph)
    """
        initialize and save ckpt
    """
    saver = tf.train.Saver(max_to_keep=100, keep_checkpoint_every_n_hours=6)
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        sess.run(tf.global_variables_initializer())
    # Publish initial weights so rollout workers can start acting.
    ws = Model.get_ws(sess)
    logging.info('pushing weight to ps')
    ray.get(ps.push.remote(ws))

    saver.save(sess, os.path.join(ckpt_dir, "CKPT"), global_step=global_step)
    """
        step
    """
    total_frames = 0
    # Prime the staging area and sync the target network before training.
    sess.run(stage_op)
    sess.run(init_target_op)
    dur_time = 0
    while total_frames < FLAGS.total_environment_frames:
        start = time.time()

        fd = {dur_time_tensor: dur_time}

        total_frames, gs, summary, _ = sess.run([
            num_frames_and_train, global_step_and_train, summary_ops, stage_op
        ],
                                                feed_dict=fd)

        # Periodically sync the target network.
        if gs % FLAGS.target_update == 0:
            sess.run(init_target_op)

        if gs % 25 == 0:
            ws = Model.get_ws(sess)
            logging.info('pushing weight to ps')
            # Best-effort push: tolerate transient Ray failures.
            try:
                ray.get(ps.push.remote(ws))
            except ray.exceptions.UnreconstructableError as e:
                logging.info(str(e))
            except ray.exceptions.RayError as e:
                logging.info(str(e))

        if gs % 1000 == 0:
            saver.save(sess,
                       os.path.join(ckpt_dir, "CKPT"),
                       global_step=global_step)

        # `gs % 1 == 0` is always true: summarize and log every step.
        if gs % 1 == 0:
            summary_writer.add_summary(summary, global_step=gs)
            dur_time = time.time() - start
            msg = "Global Step %d, Total Frames %d,  Time Consume %.2f" % (
                gs, total_frames, dur_time)
            logging.info(msg)

    saver.save(sess, os.path.join(ckpt_dir, "CKPT"), global_step=global_step)

    # Tear down: kill the run script on every worker machine via ssh.
    # ARNOLD_WORKER_NUM is expected to be set by the cluster scheduler.
    nof_workers = os.getenv('ARNOLD_WORKER_NUM', None)
    assert nof_workers is not None, nof_workers

    for i in range(int(nof_workers)):
        print('killing worker %s' % i)
        os.system('ssh worker-%s pkill run.sh' % i)