Esempio n. 1
0
def single_worker_selfplay(mutable_model, immutable_model, mutable_sess,
                           immutable_sess, selfplay_data_file,
                           selfplay_kb_file, global_step, hparams,
                           summary_writer):
    """Run self-play with a single worker; primarily used for evaluation.

    Loads the self-play data and knowledge-base files, restores both models
    through ``load_self_play_model``, and then walks over the data twice (once
    per value of ``flip``) in batches without shuffling. Every batch is handed
    to ``dialogue.talk``; the generated dialogues are written to
    ``hparams.selfplay_eval_output_file`` and the per-batch summaries are
    aggregated by ``handle_summary``. The total wall-clock time is reported
    through ``utils.add_summary``.

    Args:
      mutable_model: model passed to ``SelfplayDialogue`` as the mutable model.
      immutable_model: model passed to ``SelfplayDialogue`` as the immutable
        model.
      mutable_sess: session paired with the mutable model.
      immutable_sess: session paired with the immutable model.
      selfplay_data_file: path given to ``dialogue_utils.load_data`` for the
        self-play data.
      selfplay_kb_file: path given to ``dialogue_utils.load_data`` for the
        self-play knowledge base.
      global_step: step value under which the summaries are recorded.
      hparams: hyper-parameter object; dialogue limits, special tokens,
        checkpoint directories and the output file path are read from it.
      summary_writer: writer that receives the evaluation summaries.
    """
    eval_mode = dialogue_utils.mode_self_play_dialogue_eval

    # Load the evaluation inputs before building the dialogue object.
    selfplay_data = dialogue_utils.load_data(selfplay_data_file)
    selfplay_kb = dialogue_utils.load_data(selfplay_kb_file)

    dialogue = SelfplayDialogue(
        mutable_model,
        immutable_model,
        mutable_sess,
        immutable_sess,
        hparams.max_dialogue_turns,
        hparams.train_threadhold,
        hparams.start_of_turn1,
        hparams.start_of_turn2,
        hparams.end_of_dialogue,
        summary_writer=summary_writer,
        dialogue_mode=eval_mode,
        hparams=hparams)

    batch_size = dialogue.self_play_eval_batch_size
    assert batch_size <= len(selfplay_data)

    # Restore both models; only the returned model objects are used below.
    restored_mutable, _ = load_self_play_model(
        dialogue.mutable_model, dialogue.mutable_sess, 'mutable',
        hparams.self_play_pretrain_dir, hparams.out_dir)
    restored_immutable, _ = load_self_play_model(
        dialogue.immutable_model, dialogue.immutable_sess, 'immutable',
        hparams.self_play_pretrain_dir, hparams.out_dir)

    step_count = 0
    summaries = []
    # Number of examples behind each entry of `summaries`.
    weights = []

    # The agents' roles are flipped exactly twice: per the original notes, with
    # flip == 0 the mutable model acts as agent 1 and the immutable model as
    # agent 2, and the other way around with flip == 1 (the actual assignment
    # is decided by dialogue.flip_agent, which is defined elsewhere).
    started_at = time.time()
    with tf.gfile.GFile(hparams.selfplay_eval_output_file,
                        'w') as out_file:
        print('flip 1')
        for flip in range(2):
            mutable_side = (restored_mutable, mutable_sess,
                            dialogue.mutable_handles)
            immutable_side = (restored_immutable, immutable_sess,
                              dialogue.immutable_handles)
            agent1, agent2, _ = dialogue.flip_agent(mutable_side,
                                                    immutable_side, flip)

            # Evaluation covers exactly one un-shuffled epoch. The zip
            # round-trip pairs data with kb entries and yields two tuples.
            selfplay_data, selfplay_kb = zip(*zip(selfplay_data, selfplay_kb))
            total = len(selfplay_data)
            num_batches = int(math.ceil(float(total) / batch_size))

            # Which agent opens the conversation; constant for a given flip.
            first_speaker = flip % 2
            for batch_index in tqdm(range(num_batches)):
                lo = batch_index * batch_size
                hi = min(lo + batch_size, total)

                generated, _, summary = dialogue.talk(
                    hparams.max_dialogue_len,
                    selfplay_data[lo:hi],
                    selfplay_kb[lo:hi],
                    agent1,
                    agent2,
                    step_count,
                    batch_size,
                    first_speaker)
                output_generated_data(generated, out_file)

                summaries.append(summary)
                # The last batch may be shorter than batch_size.
                weights.append(hi - lo)
                step_count += 1

    handle_summary(eval_mode, summary_writer, global_step, summaries, weights)
    finished_at = time.time()
    print('finished')
    # Step-wise summary of the total evaluation time.
    utils.add_summary(summary_writer, global_step, eval_mode + '_time',
                      finished_at - started_at)
Esempio n. 2
0
def multi_worker_selfplay(hparams,
                          identity,
                          scope=None,
                          target_session='',
                          is_chief=True,
                          ps_tasks=0,
                          num_workers=1,
                          jobid=0,
                          startup_delay_steps=0):
    """Run multi-worker self-play, mostly used for distributed training.

    Builds a mutable and an immutable self-play model, restores the immutable
    one through ``load_self_play_model``, obtains the mutable session from a
    ``tf.train.Supervisor`` and then loops until the global step reaches
    ``hparams.num_self_play_train_steps``. Each iteration optionally reloads
    the immutable model, picks the agents via ``dialogue.flip_agent``, plays
    one batch with ``dialogue.talk`` and trains with ``dialogue.maybe_train``.
    The chief worker additionally saves checkpoints and writes summaries.

    Args:
      hparams: hyper-parameter object; data paths, directories, frequencies
        and dialogue settings are read from it.
      identity: string prefixed to the name of the summary log directory.
      scope: forwarded to ``model_helper.create_selfplay_model``.
      target_session: forwarded to ``sv.prepare_or_wait_for_session``.
      is_chief: whether this worker saves checkpoints and writes summaries;
        also forwarded to the supervisor.
      ps_tasks: forwarded to ``tf.train.replica_device_setter``.
      num_workers: forwarded to ``model_helper.create_selfplay_model``.
      jobid: forwarded to ``model_helper.create_selfplay_model`` and used to
        compute the start-up delay of this worker.
      startup_delay_steps: this worker waits until the global step reaches
        ``jobid * (jobid + 1) * startup_delay_steps / 2`` before training.
    """
    immutable_model_reload_freq = hparams.immutable_model_reload_freq
    # 1. models and summary writer
    model_creator = diag_model.Model
    extra_args = model_helper.ExtraArgs(
        single_cell_fn=None,
        model_device_fn=tf.train.replica_device_setter(ps_tasks),
        attention_mechanism_fn=None)

    mutable_model = model_helper.create_selfplay_model(model_creator,
                                                       is_mutable=True,
                                                       num_workers=num_workers,
                                                       jobid=jobid,
                                                       hparams=hparams,
                                                       scope=scope,
                                                       extra_args=extra_args)
    # The immutable model is built from a copy of hparams with num_gpus = 0.
    immutable_hparams = copy.deepcopy(hparams)
    immutable_hparams.num_gpus = 0
    immutable_model = model_helper.create_selfplay_model(
        model_creator,
        is_mutable=False,
        num_workers=num_workers,
        jobid=jobid,
        hparams=immutable_hparams,
        scope=scope)

    if hparams.self_play_immutable_gpu:
        print('using GPU for immutable')
        immutable_sess = tf.Session(
            graph=immutable_model.graph,
            config=tf.ConfigProto(allow_soft_placement=True))
    else:
        print('not using GPU for immutable')
        immutable_sess = tf.Session(graph=immutable_model.graph,
                                    config=tf.ConfigProto(
                                        allow_soft_placement=True,
                                        device_count={'GPU': 0}))

    immutable_model, immutable_sess = load_self_play_model(
        immutable_model, immutable_sess, 'immutable',
        hparams.self_play_pretrain_dir, hparams.out_dir)
    global_step = immutable_model.model.global_step.eval(
        session=immutable_sess)

    if is_chief:
        # When out_dir has no checkpoint yet, seed it with the loaded
        # (pretrained) immutable model.
        ckpt = tf.train.latest_checkpoint(hparams.out_dir)
        if not ckpt:
            print('global_step, saving pretrain model to hparams.out_dir',
                  global_step, hparams.out_dir)
            immutable_model.model.saver.save(  # this is to prevent adam error
                immutable_sess,
                os.path.join(hparams.out_dir, 'dialogue.ckpt'),
                global_step=global_step)
            print('save finished')

    if is_chief:
        summary_writer_path = os.path.join(
            hparams.out_dir, identity + task_SP_DISTRIBUTED + '_log')
        summary_writer = tf.summary.FileWriter(summary_writer_path,
                                               mutable_model.graph)
        print('summary writer established at', summary_writer_path)
    else:
        summary_writer = None

    # 2. supervisor and sessions
    sv = tf.train.Supervisor(
        graph=mutable_model.graph,
        is_chief=is_chief,
        saver=mutable_model.model.saver,
        save_model_secs=0,  # disable automatic save checkpoints
        summary_op=None,
        logdir=hparams.out_dir,
        checkpoint_basename='dialogue.ckpt')

    mutable_config = utils.get_config_proto(
        log_device_placement=hparams.log_device_placement,
        allow_soft_placement=True)
    mutable_config.device_count['GPU'] = hparams.num_gpus

    mutable_sess = sv.prepare_or_wait_for_session(target_session,
                                                  config=mutable_config)

    # 3. additional preparations
    # Stagger worker start-up: poll the global step once per second until it
    # reaches this worker's threshold.
    global_step = mutable_model.model.global_step.eval(session=mutable_sess)
    while global_step < (jobid * (jobid + 1) * startup_delay_steps / 2):
        time.sleep(1)
        global_step = mutable_model.model.global_step.eval(
            session=mutable_sess)

    # save first model
    if is_chief:
        print('saveing the first checkpoint to', hparams.out_dir)
        mutable_model.model.saver.save(mutable_sess,
                                       os.path.join(hparams.out_dir,
                                                    'dialogue.ckpt'),
                                       global_step=global_step)

    # Read data
    selfplay_data = dialogue_utils.load_data(hparams.self_play_train_data)
    selfplay_kb = dialogue_utils.load_data(hparams.self_play_train_kb)

    dialogue = SelfplayDialogue(mutable_model,
                                immutable_model,
                                mutable_sess,
                                immutable_sess,
                                hparams.max_dialogue_turns,
                                hparams.train_threadhold,
                                hparams.start_of_turn1,
                                hparams.start_of_turn2,
                                hparams.end_of_dialogue,
                                summary_writer=summary_writer,
                                dialogue_mode=task_SP_DISTRIBUTED,
                                hparams=hparams)

    # 4. main loop
    last_immmutable_model_reload = global_step
    last_save_step = global_step
    batch_size = dialogue.batch_size
    assert batch_size <= len(selfplay_data)

    # Batch index into the self-play data. Starting past the end forces a
    # shuffle on the first iteration.
    i = len(selfplay_data)
    # train_stats[k] counts the iterations in which flip_agent returned k as
    # the mutable agent index.
    train_stats = [0, 0]
    while global_step < hparams.num_self_play_train_steps:
        # a. reload the immutable model; the mutable one is managed
        # automatically by the supervisor
        if immutable_model_reload_freq > 0 and global_step - last_immmutable_model_reload > immutable_model_reload_freq:
            immutable_model, immutable_sess = load_self_play_model(
                immutable_model, immutable_sess, 'immutable',
                hparams.self_play_pretrain_dir, hparams.out_dir)
            last_immmutable_model_reload = global_step
        # b. possibly flip between speakers (or roll out models),
        # based on either a random policy or by step counts
        agent1, agent2, mutable_agent_index = dialogue.flip_agent(
            (mutable_model, mutable_sess, dialogue.mutable_handles),
            (immutable_model, immutable_sess, dialogue.immutable_handles))
        train_stats[mutable_agent_index] += 1
        # read selfplay data
        start_time = time.time()
        if i * batch_size + batch_size > len(selfplay_data):  # reached the end
            # random.shuffle needs a mutable sequence; on Python 3 zip()
            # returns an iterator, so it is materialized into a list first.
            input_data = list(zip(selfplay_data, selfplay_kb))
            random.shuffle(input_data)  # random shuffle input data
            i = 0
            selfplay_data, selfplay_kb = zip(*input_data)

        start_ind, end_ind = i * batch_size, i * batch_size + batch_size
        batch_data, batch_kb = selfplay_data[start_ind:end_ind], selfplay_kb[
            start_ind:end_ind]
        # NOTE(review): single_worker_selfplay passes
        # (..., worker_step, batch_size, speaker) after agent2, whereas this
        # call passes (..., batch_size, global_step). Confirm the positional
        # order against the signature of SelfplayDialogue.talk.
        train_example, _, _ = dialogue.talk(hparams.max_dialogue_len,
                                            batch_data, batch_kb, agent1,
                                            agent2, batch_size, global_step)
        possible_global_step = dialogue.maybe_train(train_example,
                                                    mutable_agent_index,
                                                    global_step,
                                                    force=True)
        # Only adopt the returned step when maybe_train yields a truthy value.
        if possible_global_step:
            global_step = possible_global_step
        if is_chief and global_step - last_save_step > hparams.self_play_dist_save_freq:
            mutable_model.model.saver.save(mutable_sess,
                                           os.path.join(
                                               hparams.out_dir,
                                               'dialogue.ckpt'),
                                           global_step=global_step)
            last_save_step = global_step
        end_time = time.time()

        if is_chief:
            utils.add_summary(summary_writer, global_step,
                              task_SP_DISTRIBUTED + '_' + 'time',
                              end_time - start_time)
            # The 0.1 in the denominator avoids a division by zero.
            utils.add_summary(summary_writer, global_step,
                              task_SP_DISTRIBUTED + '_' + 'train_ratio',
                              train_stats[0] * 1.0 / (train_stats[1] + 0.1))
        i += 1

    if is_chief:
        summary_writer.close()

    mutable_sess.close()
    immutable_sess.close()