def single_worker_inference(infer_model, infer_sess, eval_model, eval_sess,
                            ckpt, summary_writer, global_step, hparams):
    """Decode the inference set, score it, and log the results.

    The function restores the inference model from ``ckpt``, feeds the
    inference data through the inference iterator, decodes and evaluates
    the output, decodes a sample for qualitative inspection, and finally
    runs the internal evaluation to obtain the dev perplexity.  Every
    score, the perplexity and the elapsed time are written with
    ``utils.add_summary``.

    Args:
        infer_model: inference model bundle; its ``graph``, ``model``,
            ``infer_iterator`` and the data/kb/batch-size placeholders
            are used here.
        infer_sess: session used to run the inference graph.
        eval_model: evaluation model bundle; its ``eval_iterator`` is used.
        eval_sess: session used to run the evaluation graph.
        ckpt: checkpoint handed to ``model_helper.load_model``.
        summary_writer: writer passed to ``utils.add_summary``.
        global_step: step value attached to every summary.
        hparams: hyper-parameters; reads ``infer_src_data``,
            ``infer_tar_data``, ``infer_kb``, ``infer_batch_size``,
            ``inference_output_file``, ``metrics`` and ``out_dir``.
    """
    # Read the three inference inputs named in hparams.
    src_data = load_data(hparams.infer_src_data)
    tar_data = load_data(hparams.infer_tar_data)
    kb_data = load_data(hparams.infer_kb)

    # The timer starts after data loading; it covers model restore,
    # decoding, sample decoding and the perplexity evaluation.
    started_at = time.time()
    with infer_model.graph.as_default():
        loaded_infer_model = model_helper.load_model(
            infer_model.model, ckpt, infer_sess, "infer")

        # Point the inference iterator at the freshly loaded data.
        iterator_feed = {
            infer_model.data_src_placeholder: src_data,
            infer_model.kb_placeholder: kb_data,
            infer_model.batch_size_placeholder: hparams.infer_batch_size,
        }
        infer_sess.run(
            infer_model.infer_iterator.initializer, feed_dict=iterator_feed)
        infer_handle = infer_sess.run(
            infer_model.infer_iterator.string_handle())

        # Decode the whole inference set and compute the requested metrics.
        utils.print_out("# Start decoding")
        scores = dialogue_utils.decode_and_evaluate(
            "infer",
            loaded_infer_model,
            infer_handle,
            infer_sess,
            hparams.inference_output_file,
            ref_file=hparams.infer_tar_data,
            metrics=hparams.metrics,
            hparams=hparams,
            infer_src_data=src_data)

        # One summary entry per returned score.
        for score_name in scores:
            utils.add_summary(
                summary_writer, global_step, score_name, scores[score_name])

        # Decode a sample of dialogues for qualitative examination.
        _sample_decode(
            loaded_infer_model,
            global_step,
            infer_handle,
            infer_sess,
            hparams,
            infer_model.infer_iterator,
            src_data,
            tar_data,
            kb_data,
            infer_model.data_src_placeholder,
            infer_model.kb_placeholder,
            infer_model.batch_size_placeholder)

    # Run the evaluation model to get the dev perplexity.
    eval_handle = eval_sess.run(eval_model.eval_iterator.string_handle())
    dev_ppl, _ = run_internal_eval(
        eval_model, eval_handle, eval_sess, hparams.out_dir, hparams,
        summary_writer)
    utils.add_summary(summary_writer, global_step, "dev_ppl", dev_ppl)

    elapsed = time.time() - started_at
    utils.add_summary(summary_writer, global_step, "infer_time", elapsed)
def single_worker_selfplay(mutable_model, immutable_model, mutable_sess,
                           immutable_sess, selfplay_data_file,
                           selfplay_kb_file, global_step, hparams,
                           summary_writer):
    """Run self-play evaluation on a single worker.

    Both models are loaded through ``load_self_play_model`` and then talk
    to each other over the whole self-play data set twice, once per value
    of ``flip``.  Generated dialogues are written to
    ``hparams.selfplay_eval_output_file``; the per-batch summaries are
    combined by ``handle_summary`` and the total wall time is logged.

    Args:
        mutable_model: model bundle for the mutable agent.
        immutable_model: model bundle for the immutable agent.
        mutable_sess: session for the mutable model.
        immutable_sess: session for the immutable model.
        selfplay_data_file: path read with ``dialogue_utils.load_data``.
        selfplay_kb_file: path read with ``dialogue_utils.load_data``.
        global_step: step value attached to the summaries.
        hparams: hyper-parameters; reads the dialogue settings, the
            pretrain/out directories and ``selfplay_eval_output_file``.
        summary_writer: writer passed to the dialogue object and to the
            summary helpers.
    """
    dialogue_mode = dialogue_utils.mode_self_play_dialogue_eval

    # Read self-play data and the matching knowledge-base entries.
    selfplay_data = dialogue_utils.load_data(selfplay_data_file)
    selfplay_kb = dialogue_utils.load_data(selfplay_kb_file)

    # Build the dialogue driver that owns both agents.
    dialogue = SelfplayDialogue(
        mutable_model,
        immutable_model,
        mutable_sess,
        immutable_sess,
        hparams.max_dialogue_turns,
        hparams.train_threadhold,
        hparams.start_of_turn1,
        hparams.start_of_turn2,
        hparams.end_of_dialogue,
        summary_writer=summary_writer,
        dialogue_mode=dialogue_mode,
        hparams=hparams)
    batch_size = dialogue.self_play_eval_batch_size
    assert batch_size <= len(selfplay_data)

    loaded_mutable, _ = load_self_play_model(
        dialogue.mutable_model, dialogue.mutable_sess, 'mutable',
        hparams.self_play_pretrain_dir, hparams.out_dir)
    loaded_immutable, _ = load_self_play_model(
        dialogue.immutable_model, dialogue.immutable_sess, 'immutable',
        hparams.self_play_pretrain_dir, hparams.out_dir)

    worker_step = 0
    all_summary = []
    summary_weight = []  # number of examples behind each entry of all_summary

    # The agent roles are flipped exactly twice.  Per the original author's
    # note: with flip = 0 the mutable model is agent 1 and the immutable
    # model is agent 2; with flip = 1 it is the other way around.
    start_time = time.time()
    num_flips_for_initial_speaker = 2
    with tf.gfile.GFile(hparams.selfplay_eval_output_file, 'w') as selfplay_out:
        print('flip 1')
        for flip in range(num_flips_for_initial_speaker):
            agent1, agent2, _ = dialogue.flip_agent(
                (loaded_mutable, mutable_sess, dialogue.mutable_handles),
                (loaded_immutable, immutable_sess, dialogue.immutable_handles),
                flip)

            # Pair every example with its kb entry and unpack again.  No
            # shuffling happens in evaluation; the round trip leaves both
            # collections as equal-length tuples.
            selfplay_data, selfplay_kb = zip(*zip(selfplay_data, selfplay_kb))

            total = len(selfplay_data)
            num_batches = int(math.ceil(total * 1.0 / batch_size))
            # Which agent is asked to talk first; constant for a given flip.
            speaker = flip % 2

            # Evaluate exactly one pass over the data; the last batch may be
            # shorter than batch_size.
            for batch_index in tqdm(range(num_batches)):
                start_ind = batch_index * batch_size
                end_ind = min(start_ind + batch_size, total)
                batch_data = selfplay_data[start_ind:end_ind]
                batch_kb = selfplay_kb[start_ind:end_ind]

                generated_data, _, summary = dialogue.talk(
                    hparams.max_dialogue_len, batch_data, batch_kb, agent1,
                    agent2, worker_step, batch_size, speaker)
                output_generated_data(generated_data, selfplay_out)

                all_summary.append(summary)
                # Weight each summary by the number of examples processed.
                summary_weight.append(end_ind - start_ind)
                worker_step += 1

    handle_summary(dialogue_mode, summary_writer, global_step, all_summary,
                   summary_weight)
    end_time = time.time()
    print('finished')
    utils.add_summary(summary_writer, global_step, dialogue_mode + '_time',
                      end_time - start_time)
def multi_worker_selfplay(hparams,
                          identity,
                          scope=None,
                          target_session='',
                          is_chief=True,
                          ps_tasks=0,
                          num_workers=1,
                          jobid=0,
                          startup_delay_steps=0):
    """Run the self-play training loop for one worker of a distributed job.

    A mutable model (managed by a ``tf.train.Supervisor``) and an immutable
    model (in its own session) are created, the self-play training data is
    loaded, and the two models talk to each other batch by batch until the
    global step reaches ``hparams.num_self_play_train_steps``.  The chief
    worker additionally writes checkpoints and summaries.

    Args:
        hparams: hyper-parameters for models, data paths, dialogue settings
            and the reload/save frequencies read throughout this function.
        identity: string prefixed to the summary log directory name.
        scope: forwarded to ``model_helper.create_selfplay_model``.
        target_session: session target handed to
            ``Supervisor.prepare_or_wait_for_session``.
        is_chief: whether this worker is the chief; only the chief saves
            checkpoints and writes summaries.
        ps_tasks: forwarded to ``tf.train.replica_device_setter``.
        num_workers: forwarded to ``model_helper.create_selfplay_model``.
        jobid: forwarded to ``model_helper.create_selfplay_model`` and used
            to compute this worker's start-up delay.
        startup_delay_steps: the worker sleeps until the global step reaches
            ``jobid * (jobid + 1) * startup_delay_steps / 2``.
    """
    immutable_model_reload_freq = hparams.immutable_model_reload_freq

    # 1. models and summary writer
    model_creator = diag_model.Model
    extra_args = model_helper.ExtraArgs(
        single_cell_fn=None,
        model_device_fn=tf.train.replica_device_setter(ps_tasks),
        attention_mechanism_fn=None)

    mutable_model = model_helper.create_selfplay_model(
        model_creator,
        is_mutable=True,
        num_workers=num_workers,
        jobid=jobid,
        hparams=hparams,
        scope=scope,
        extra_args=extra_args)

    # The immutable model is built from a private copy of hparams with
    # num_gpus forced to 0, so the caller's hparams stay untouched.
    immutable_hparams = copy.deepcopy(hparams)
    immutable_hparams.num_gpus = 0
    immutable_model = model_helper.create_selfplay_model(
        model_creator,
        is_mutable=False,
        num_workers=num_workers,
        jobid=jobid,
        hparams=immutable_hparams,
        scope=scope)

    if hparams.self_play_immutable_gpu:
        print('using GPU for immutable')
        immutable_sess = tf.Session(
            graph=immutable_model.graph,
            config=tf.ConfigProto(allow_soft_placement=True))
    else:
        print('not using GPU for immutable')
        immutable_sess = tf.Session(
            graph=immutable_model.graph,
            config=tf.ConfigProto(
                allow_soft_placement=True, device_count={'GPU': 0}))

    immutable_model, immutable_sess = load_self_play_model(
        immutable_model, immutable_sess, 'immutable',
        hparams.self_play_pretrain_dir, hparams.out_dir)
    global_step = immutable_model.model.global_step.eval(
        session=immutable_sess)

    if is_chief:
        ckpt = tf.train.latest_checkpoint(hparams.out_dir)
        if not ckpt:
            # No checkpoint in out_dir yet: seed it with the loaded
            # immutable model.  The original author notes this is done to
            # prevent an "adam error".
            print('global_step, saving pretrain model to hparams.out_dir',
                  global_step, hparams.out_dir)
            immutable_model.model.saver.save(
                immutable_sess,
                os.path.join(hparams.out_dir, 'dialogue.ckpt'),
                global_step=global_step)
            print('save finished')

    if is_chief:
        summary_writer_path = os.path.join(
            hparams.out_dir, identity + task_SP_DISTRIBUTED + '_log')
        summary_writer = tf.summary.FileWriter(summary_writer_path,
                                               mutable_model.graph)
        print('summary writer established at', summary_writer_path)
    else:
        summary_writer = None

    # 2. supervisor and sessions
    sv = tf.train.Supervisor(
        graph=mutable_model.graph,
        is_chief=is_chief,
        saver=mutable_model.model.saver,
        save_model_secs=0,  # disable automatic save checkpoints
        summary_op=None,
        logdir=hparams.out_dir,
        checkpoint_basename='dialogue.ckpt')

    mutable_config = utils.get_config_proto(
        log_device_placement=hparams.log_device_placement,
        allow_soft_placement=True)
    mutable_config.device_count['GPU'] = hparams.num_gpus

    mutable_sess = sv.prepare_or_wait_for_session(
        target_session, config=mutable_config)

    # 3. additional preparations
    # Stagger worker start-up: poll once per second until the shared global
    # step has reached this worker's delay threshold.
    global_step = mutable_model.model.global_step.eval(session=mutable_sess)
    while global_step < (jobid * (jobid + 1) * startup_delay_steps / 2):
        time.sleep(1)
        global_step = mutable_model.model.global_step.eval(
            session=mutable_sess)

    # save first model
    if is_chief:
        print('saveing the first checkpoint to', hparams.out_dir)
        mutable_model.model.saver.save(
            mutable_sess,
            os.path.join(hparams.out_dir, 'dialogue.ckpt'),
            global_step=global_step)

    # Read data
    selfplay_data = dialogue_utils.load_data(hparams.self_play_train_data)
    selfplay_kb = dialogue_utils.load_data(hparams.self_play_train_kb)

    dialogue = SelfplayDialogue(
        mutable_model,
        immutable_model,
        mutable_sess,
        immutable_sess,
        hparams.max_dialogue_turns,
        hparams.train_threadhold,
        hparams.start_of_turn1,
        hparams.start_of_turn2,
        hparams.end_of_dialogue,
        summary_writer=summary_writer,
        dialogue_mode=task_SP_DISTRIBUTED,
        hparams=hparams)

    # 4. main loop
    last_immmutable_model_reload = global_step
    last_save_step = global_step
    batch_size = dialogue.batch_size
    assert batch_size <= len(selfplay_data)

    # Batch cursor.  Starting past the end of the data forces a shuffle on
    # the first iteration of the loop below.
    i = len(selfplay_data)
    # train_stats[k] counts the iterations in which flip_agent reported
    # mutable_agent_index == k.
    train_stats = [0, 0]
    while global_step < hparams.num_self_play_train_steps:
        # a. reload the immutable model; the mutable one is managed
        # automatically by the supervisor
        if (immutable_model_reload_freq > 0 and
                global_step - last_immmutable_model_reload >
                immutable_model_reload_freq):
            immutable_model, immutable_sess = load_self_play_model(
                immutable_model, immutable_sess, 'immutable',
                hparams.self_play_pretrain_dir, hparams.out_dir)
            last_immmutable_model_reload = global_step

        # b. possibly flip between speakers (or roll out models), based on
        # either a random policy or by step counts
        agent1, agent2, mutable_agent_index = dialogue.flip_agent(
            (mutable_model, mutable_sess, dialogue.mutable_handles),
            (immutable_model, immutable_sess, dialogue.immutable_handles))
        train_stats[mutable_agent_index] += 1

        # read selfplay data
        start_time = time.time()
        if i * batch_size + batch_size > len(selfplay_data):
            # Reached the end: reshuffle and restart from the first batch.
            # zip() returns an iterator on Python 3 and random.shuffle needs
            # a mutable sequence, so the pairs are materialised as a list.
            input_data = list(zip(selfplay_data, selfplay_kb))
            random.shuffle(input_data)
            i = 0
            selfplay_data, selfplay_kb = zip(*input_data)

        start_ind = i * batch_size
        end_ind = start_ind + batch_size
        batch_data = selfplay_data[start_ind:end_ind]
        batch_kb = selfplay_kb[start_ind:end_ind]

        # NOTE(review): single_worker_selfplay calls dialogue.talk with
        # (..., worker_step, batch_size, speaker) while this call passes
        # (..., batch_size, global_step) -- confirm the positional order
        # against SelfplayDialogue.talk.
        train_example, _, _ = dialogue.talk(
            hparams.max_dialogue_len, batch_data, batch_kb, agent1, agent2,
            batch_size, global_step)
        possible_global_step = dialogue.maybe_train(
            train_example, mutable_agent_index, global_step, force=True)
        if possible_global_step:
            global_step = possible_global_step

        if (is_chief and global_step - last_save_step >
                hparams.self_play_dist_save_freq):
            mutable_model.model.saver.save(
                mutable_sess,
                os.path.join(hparams.out_dir, 'dialogue.ckpt'),
                global_step=global_step)
            last_save_step = global_step
        end_time = time.time()

        if is_chief:
            utils.add_summary(summary_writer, global_step,
                              task_SP_DISTRIBUTED + '_' + 'time',
                              end_time - start_time)
            # The 0.1 in the denominator avoids division by zero while
            # train_stats[1] is still 0.
            utils.add_summary(summary_writer, global_step,
                              task_SP_DISTRIBUTED + '_' + 'train_ratio',
                              train_stats[0] * 1.0 / (train_stats[1] + 0.1))
        i += 1

    if is_chief:
        summary_writer.close()

    mutable_sess.close()
    immutable_sess.close()