def secondary_fn_tmp(hparams, identity, model_dir, model, eval_model, eval_sess,
                     name, worker_fn):
  """Secondary helper function for inference and evaluation."""
  steps_per_external_eval = 10

  # Initialize summary writer.
  summary_writer_path = os.path.join(hparams.out_dir, identity + name + "_log")
  print("summary_writer_path", summary_writer_path)
  summary_writer = tf.summary.FileWriter(summary_writer_path, model.graph)

  config_proto = utils.get_config_proto(
      log_device_placement=hparams.log_device_placement,
      allow_soft_placement=True)

  # Create session.
  sess = tf.Session(config=config_proto, graph=model.graph)

  # Wait for checkpoints.
  latest_ckpt = None
  last_external_eval_step = 0

  # Main inference loop.
  while True:
    latest_ckpt = tf.contrib.training.wait_for_new_checkpoint(
        model_dir, latest_ckpt)
    with model.graph.as_default():
      _, global_step = model_helper.create_or_load_model(
          model.model, model_dir, sess, name)
    if global_step - last_external_eval_step >= steps_per_external_eval:
      last_external_eval_step = global_step
      worker_fn(model, sess, eval_model, eval_sess, latest_ckpt,
                summary_writer, global_step, hparams)
    if not hparams.eval_forever:
      break  # If eval_forever is disabled, we only evaluate once.

  summary_writer.close()
  sess.close()
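# A minimal, hypothetical worker_fn stub, for illustration only: it shows the
# callback signature that secondary_fn_tmp invokes whenever a new checkpoint
# appears. The body is an assumption, not code from the source.
def example_worker_fn(model, sess, eval_model, eval_sess, latest_ckpt,
                      summary_writer, global_step, hparams):
  # e.g. decode with `model`/`sess`, score with `eval_model`/`eval_sess`,
  # then log a scalar under the current global_step.
  utils.add_summary(summary_writer, global_step, "example_metric", 0.0)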
def get_answer(self, question, id_spk):
  infer_data = [clean_text(question)]
  infer_model = self.infer_model
  with tf.Session(
      graph=infer_model.graph, config=utils.get_config_proto()) as sess:
    loaded_infer_model = model_helper.load_model(
        infer_model.model, self.ckpt, sess, "infer")
    sess.run(
        infer_model.iterator.initializer,
        feed_dict={
            infer_model.src_placeholder: infer_data,
            infer_model.batch_size_placeholder: 1,
            infer_model.src_speaker_placeholder: id_spk,
            infer_model.tgt_speaker_placeholder: self.id
        })
    nmt_outputs, _ = loaded_infer_model.decode(sess)
    translation = []
    for beam_id in range(self.num_translations_per_input):
      # Use sent_id=0 because the batch size is 1.
      translation.append(nmt_utils.get_translation(
          nmt_outputs=nmt_outputs[beam_id],
          sent_id=0,
          tgt_eos=self.hparams.eos,
          subword_option=self.hparams.subword_option))
    return translation
def __init__(self, hparams):
  self.hparams = hparams

  # Data locations
  self.out_dir = hparams.out_dir
  self.model_dir = os.path.join(self.out_dir, 'ckpts')

  # Create models
  attention_option = hparams.attention_option
  if attention_option:
    model_creator = AttentionModel
  else:
    model_creator = BasicModel
  self.infer_model = model_helper.create_infer_model(
      hparams=hparams, model_creator=model_creator)

  # Sessions
  config_proto = utils.get_config_proto()
  self.infer_sess = tf.Session(config=config_proto,
                               graph=self.infer_model.graph)

  # EOS
  self.tgt_eos = Vocabulary.EOS.encode("utf-8")

  # Load infer model
  with self.infer_model.graph.as_default():
    self.loaded_infer_model, self.global_step = (
        model_helper.create_or_load_model(
            self.infer_model.model, self.model_dir, self.infer_sess, "infer"))
def start_sess_and_load_model(infer_model, ckpt_path):
  """Start session and load model."""
  sess = tf.Session(graph=infer_model.graph, config=utils.get_config_proto())
  with infer_model.graph.as_default():
    loaded_infer_model = model_helper.load_model(
        infer_model.model, ckpt_path, sess, "infer")
  return sess, loaded_infer_model
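# A hypothetical usage sketch (not from the source): the returned pair keeps
# the session bound to the inference graph, so later decode calls reuse both.
# `my_infer_model` and `my_ckpt` are assumed to exist; the placeholder names
# follow the pattern used by the other snippets in this file.
sess, loaded_infer_model = start_sess_and_load_model(my_infer_model, my_ckpt)
sess.run(my_infer_model.iterator.initializer,
         feed_dict={my_infer_model.src_placeholder: ["hello world"],
                    my_infer_model.batch_size_placeholder: 1})
nmt_outputs, _ = loaded_infer_model.decode(sess)
sess.close()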
def single_worker_inference(infer_model, ckpt, inference_input_file,
                            inference_output_file, hparams):
  """Inference with a single worker."""
  output_infer = inference_output_file

  # Read data
  infer_data = load_data(inference_input_file, hparams)

  with tf.Session(config=utils.get_config_proto(),
                  graph=infer_model.graph) as sess:
    loaded_infer_model = model_helper.load_model(
        infer_model.model, ckpt, sess, "infer")
    sess.run(infer_model.iterator.initializer,
             feed_dict={
                 infer_model.src_placeholder: infer_data,
                 infer_model.batch_size_placeholder: hparams.infer_batch_size
             })
    # Decode
    utils.print_out("# Start decoding")
    _decode_and_evaluate(
        "infer",
        loaded_infer_model,
        sess,
        output_infer,
        ref_file=None,
        subword_option=None,
        beam_width=hparams.beam_width,
        tgt_eos=hparams.eos,
        num_translations_per_input=hparams.num_translations_per_input)
def infer(hparams):
  infer_model = mc.create_infer_model(hparams)

  # TensorFlow model
  config_proto = utils.get_config_proto(
      log_device_placement=hparams.log_device_placement,
      allow_soft_placement=True)
  infer_sess = tf.Session(graph=infer_model.graph, config=config_proto)

  infer_util.run_infer(hparams, infer_sess, infer_model, 0,
                       ["bleu", "rouge", "accuracy"])
def export_model(config, model_creator):
  if not config.export_path:
    raise ValueError("Export path must be specified.")
  if not config.model_version:
    raise ValueError("Export model version must be specified.")
  utils.makedir(config.export_path)

  # Create model
  model = model_helper.create_model(model_creator, config, mode="infer")

  # TensorFlow model
  config_proto = utils.get_config_proto()
  sess = tf.Session(config=config_proto, graph=model.graph)
  with model.graph.as_default():
    loaded_model, global_step = model_helper.create_or_load_model(
        model.model, config.best_eval_loss_dir, sess, "infer")

  export_dir = os.path.join(config.export_path, config.model_version)
  builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
  inputs = {
      "word_ids1":
          tf.saved_model.utils.build_tensor_info(loaded_model.word_ids1),
      "word_ids2":
          tf.saved_model.utils.build_tensor_info(loaded_model.word_ids2),
      "word_len1":
          tf.saved_model.utils.build_tensor_info(loaded_model.word_len1),
      "word_len2":
          tf.saved_model.utils.build_tensor_info(loaded_model.word_len2),
      "char_ids1":
          tf.saved_model.utils.build_tensor_info(loaded_model.char_ids1),
      "char_ids2":
          tf.saved_model.utils.build_tensor_info(loaded_model.char_ids2),
      "char_len1":
          tf.saved_model.utils.build_tensor_info(loaded_model.char_len1),
      "char_len2":
          tf.saved_model.utils.build_tensor_info(loaded_model.char_len2)
  }
  outputs = {
      "simscore": tf.saved_model.utils.build_tensor_info(loaded_model.simscore)
  }
  prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
      inputs=inputs,
      outputs=outputs,
      method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
  builder.add_meta_graph_and_variables(
      sess, [tf.saved_model.tag_constants.SERVING], {
          tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
              prediction_signature
      })
  builder.save()
  logger.info("Export model succeeded.")
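# A minimal client-side sketch (an assumption, not part of the source repo):
# loading the SavedModel exported above with the TF1 loader API and feeding
# the named signature inputs. `export_dir` and the *_batch feed values are
# illustrative placeholders.
import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
  meta_graph = tf.saved_model.loader.load(
      sess, [tf.saved_model.tag_constants.SERVING], export_dir)
  sig = meta_graph.signature_def[
      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
  # Tensor names from the signature map directly to feed/fetch arguments.
  score = sess.run(
      sig.outputs["simscore"].name,
      feed_dict={sig.inputs["word_ids1"].name: word_ids1_batch,
                 sig.inputs["word_ids2"].name: word_ids2_batch,
                 sig.inputs["word_len1"].name: word_len1_batch,
                 sig.inputs["word_len2"].name: word_len2_batch,
                 sig.inputs["char_ids1"].name: char_ids1_batch,
                 sig.inputs["char_ids2"].name: char_ids2_batch,
                 sig.inputs["char_len1"].name: char_len1_batch,
                 sig.inputs["char_len2"].name: char_len2_batch})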
def run_sample_decode_pungan_prepare(hparams, scope=None, target_session=""):
  log_device_placement = hparams.log_device_placement
  out_dir = hparams.out_dir

  if not hparams.attention:  # choose this model
    model_creator = nmt_model.Model
  else:  # Attention
    if (hparams.encoder_type == "gnmt" or
        hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
      model_creator = gnmt_model.GNMTModel
    elif hparams.attention_architecture == "standard":
      model_creator = attention_model.AttentionModel
    else:
      raise ValueError("Unknown attention architecture %s" %
                       hparams.attention_architecture)

  infer_model = model_helper.create_infer_model(model_creator, hparams, scope)

  wsd_src_file = "%s" % (hparams.sample_prefix)
  wsd_src_data = inference.load_data(wsd_src_file)
  model_dir = hparams.out_dir

  # TensorFlow model
  config_proto = utils.get_config_proto(
      log_device_placement=log_device_placement,
      num_intra_threads=hparams.num_intra_threads,
      num_inter_threads=hparams.num_inter_threads)
  infer_sess = tf.Session(target=target_session, config=config_proto,
                          graph=infer_model.graph)

  eval_result = run_sample_decode_pungan(infer_model, infer_sess, model_dir,
                                         hparams, wsd_src_data)
  print('eval_result', eval_result)
  print('eval_result.len', len(eval_result))
  print('wsd_src_data', wsd_src_data)
  print('wsd_src_data.len', len(wsd_src_data))
  # len=16 wsd_src_data = [u'problem%1:10:00::', u'problem%1:26:00::',
  #     u'drive%1:04:00::', u'drive%1:04:03::', u'identity%1:07:00::',
  #     u'identity%1:24:01::', u'point%1:10:01::', u'point%1:06:00::',
  #     u'tension%1:26:01::', u'tension%1:26:03::', u'log%2:32:00::',
  #     u'log%2:35:00::', u'fan%1:06:00::', u'fan%1:18:00::',
  #     u'file%2:32:00::', u'file%2:35:00::']

  # Each block of 2 * sample_size decoded sentences shares one pair of sense
  # keys; prefix every sentence with the first key. Explicit integer division.
  eval_result_new = []
  for block in range(len(eval_result) // (2 * hparams.sample_size)):
    src_word1, src_word2 = wsd_src_data[2 * block], wsd_src_data[2 * block + 1]
    for sent_id in range(block * hparams.sample_size,
                         (block + 1) * hparams.sample_size):
      tgt_sent = src_word1.decode().encode('utf-8') + ' ' + eval_result[sent_id]
      eval_result_new.append(tgt_sent)
  return wsd_src_data, eval_result_new
def single_worker_inference(infer_model, ckpt, inference_input_file,
                            inference_output_file, hparams, model_creator):
  """Inference with a single worker."""
  output_infer = inference_output_file

  # Read data
  infer_data = load_data(inference_input_file, hparams)

  with tf.Session(graph=infer_model.graph,
                  config=utils.get_config_proto()) as sess:
    loaded_infer_model = model_helper.load_model(
        infer_model.model, ckpt, sess, "infer")
    sess.run(infer_model.iterator.initializer,
             feed_dict={
                 infer_model.src_placeholder: infer_data,
                 infer_model.batch_size_placeholder: hparams.infer_batch_size
             })
    # Decode
    utils.print_out("# Start decoding single_worker_inference")
    if hparams.inference_indices:
      _decode_inference_indices(
          loaded_infer_model,
          sess,
          output_infer=output_infer,
          output_infer_summary_prefix=output_infer,
          inference_indices=hparams.inference_indices,
          tgt_eos=hparams.eos,
          subword_option=hparams.subword_option)
    else:
      nmt_utils.decode_and_evaluate(
          "infer",
          loaded_infer_model,
          sess,
          output_infer,
          ref_file=None,
          metrics=hparams.metrics,
          subword_option=hparams.subword_option,
          beam_width=hparams.beam_width,
          tgt_eos=hparams.eos,
          num_translations_per_input=hparams.num_translations_per_input)
def start_sess_and_load_model(infer_model, ckpt_path, hparams):
  """Start session and load model."""
  print("num_intra_threads = %d, num_inter_threads = %d \n" %
        (hparams.num_intra_threads, hparams.num_inter_threads))
  sess = tf.Session(
      graph=infer_model.graph,
      config=utils.get_config_proto(
          num_intra_threads=hparams.num_intra_threads,
          num_inter_threads=hparams.num_inter_threads))
  with infer_model.graph.as_default():
    loaded_infer_model = model_helper.load_model(
        infer_model.model, ckpt_path, sess, "infer")
  return sess, loaded_infer_model
def test(config, model_creator):
  best_metric_label = "best_eval_loss"
  model_dir = getattr(config, best_metric_label + "_dir")

  # Evaluate the saved best model on the training, dev, and test sets.
  eval_sets = [("training-set", config.train_file),
               ("dev-set", config.dev_file),
               ("test-set", config.test_file)]
  for set_name, data_file in eval_sets:
    logger.info("Start evaluating saved best model on %s." % set_name)
    eval_model = model_helper.create_model(model_creator, config, mode="eval")
    session_config = utils.get_config_proto()
    eval_sess = tf.Session(config=session_config, graph=eval_model.graph)
    run_test(config, eval_model, eval_sess, data_file, model_dir)
def run_prediction(input_file_path, output_file_path):
  infile = 'input_file'
  word_split(input_file_path, infile, jieba_split)
  model_dir = 'jb_attention'
  hparams = utils.load_hparams(model_dir)
  hparams.inference_indices = list(range(150))
  sample_src_dataset = inference.load_data(infile)
  log_device_placement = hparams.log_device_placement

  if not hparams.attention:
    model_creator = nmt_model.Model
  else:
    if (hparams.encoder_type == 'gnmt' or
        hparams.attention_architecture in ['gnmt', 'gnmt_v2']):
      model_creator = gnmt_model.GNMTModel
    elif hparams.attention_architecture == 'standard':
      model_creator = attention_model.AttentionModel
    else:
      raise ValueError('Unknown attention architecture %s' %
                       (hparams.attention_architecture))

  infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                scope=None)
  config_proto = utils.get_config_proto(
      log_device_placement=log_device_placement,
      num_intra_threads=hparams.num_intra_threads,
      num_inter_threads=hparams.num_inter_threads)
  infer_sess = tf.Session(target='', config=config_proto,
                          graph=infer_model.graph)
  with infer_model.graph.as_default():
    loaded_infer_model, global_step = model_helper.create_or_load_model(
        infer_model.model, model_dir, infer_sess, 'infer')

  iterator_feed_dict = {
      infer_model.src_placeholder: sample_src_dataset,
      infer_model.batch_size_placeholder: 1,
  }
  infer_sess.run(infer_model.iterator.initializer,
                 feed_dict=iterator_feed_dict)
  while True:
    try:
      # Decode with the loaded model (the infer_model tuple has no decode()).
      nmt_outputs, _ = loaded_infer_model.decode(infer_sess)
    except tf.errors.OutOfRangeError:
      break
def __init__(self, hparams):
  self.hparams = hparams

  # Data locations
  self.out_dir = hparams.out_dir
  self.model_dir = os.path.join(self.out_dir, 'ckpts')
  if not tf.gfile.Exists(self.model_dir):
    tf.gfile.MakeDirs(self.model_dir)
  self.train_src_file = os.path.join(
      hparams.data_dir, hparams.train_prefix + '.' + hparams.src_suffix)
  self.train_tgt_file = os.path.join(
      hparams.data_dir, hparams.train_prefix + '.' + hparams.tgt_suffix)
  self.test_src_file = os.path.join(
      hparams.data_dir, hparams.test_prefix + '.' + hparams.src_suffix)
  self.test_tgt_file = os.path.join(
      hparams.data_dir, hparams.test_prefix + '.' + hparams.tgt_suffix)
  self.dev_src_file = os.path.join(
      hparams.data_dir, hparams.dev_prefix + '.' + hparams.src_suffix)
  self.dev_tgt_file = os.path.join(
      hparams.data_dir, hparams.dev_prefix + '.' + hparams.tgt_suffix)
  self.infer_out_file = os.path.join(self.out_dir, 'infer_output')
  self.eval_out_file = os.path.join(self.out_dir, 'eval_output')

  # Create models
  attention_option = hparams.attention_option
  if attention_option:
    model_creator = AttentionModel
  else:
    model_creator = BasicModel
  self.train_model = model_helper.create_train_model(
      hparams=hparams, model_creator=model_creator)
  self.eval_model = model_helper.create_eval_model(
      hparams=hparams, model_creator=model_creator)
  self.infer_model = model_helper.create_infer_model(
      hparams=hparams, model_creator=model_creator)

  # Sessions
  config_proto = utils.get_config_proto()
  self.train_sess = tf.Session(config=config_proto,
                               graph=self.train_model.graph)
  self.eval_sess = tf.Session(config=config_proto,
                              graph=self.eval_model.graph)
  self.infer_sess = tf.Session(config=config_proto,
                               graph=self.infer_model.graph)

  # EOS
  self.tgt_eos = Vocabulary.EOS.encode("utf-8")
def infer_fn(hparams, identity, scope=None, extra_args=None, target_session=""):
  """Main entry point for inference and evaluation."""
  # Create infer and eval models.
  infer_model = model_helper.create_infer_model(
      diag_model.Model, hparams, scope, extra_args=extra_args)
  eval_model = model_helper.create_eval_model(diag_model.Model, hparams, scope)
  config_proto = utils.get_config_proto(
      log_device_placement=hparams.log_device_placement,
      allow_soft_placement=True)
  # Create the eval session.
  eval_sess = tf.Session(
      target=target_session, config=config_proto, graph=eval_model.graph)
  secondary_fn_tmp(hparams, identity, hparams.out_dir, infer_model, eval_model,
                   eval_sess, "infer", single_worker_inference)
def inference(infer_data):
  with tf.Session(graph=infer_model.graph,
                  config=utils.get_config_proto()) as sess:
    loaded_infer_model = model_helper.load_model(
        infer_model.model, ckpt, sess, "infer")
    sess.run(infer_model.iterator.initializer,
             feed_dict={
                 infer_model.src_placeholder: [infer_data],
                 infer_model.batch_size_placeholder: hparams.infer_batch_size
             })
    translation = decode_inference_indices(loaded_infer_model, sess)
  return translation
def inference(config, model_creator):
  output_file = "output_" + os.path.split(config.infer_file)[-1].split(".")[0]

  # Inference output location
  pred_file = os.path.join(config.model_dir, output_file)
  utils.makedir(pred_file)

  # Inference
  model_dir = config.best_eval_loss_dir

  # Create model
  infer_model = model_helper.create_model(model_creator, config, mode="infer")

  # TensorFlow model
  sess_config = utils.get_config_proto()
  infer_sess = tf.Session(config=sess_config, graph=infer_model.graph)
  with infer_model.graph.as_default():
    loaded_infer_model, _ = model_helper.create_or_load_model(
        infer_model.model, model_dir, infer_sess, "infer")
  run_infer(config, loaded_infer_model, infer_sess, pred_file)
def single_worker_inference(infer_model, ckpt, inference_input_file,
                            inference_output_file, hparams):
  """Inference with a single worker."""
  output_infer = inference_output_file

  # Read data
  infer_data = load_data(inference_input_file, hparams)

  with tf.Session(graph=infer_model.graph,
                  config=utils.get_config_proto()) as sess:
    loaded_infer_model = model_helper.load_model(
        infer_model.model, ckpt, sess, "infer")
    sess.run(infer_model.iterator.initializer,
             feed_dict={
                 infer_model.src_placeholder: infer_data,
                 infer_model.batch_size_placeholder: hparams.infer_batch_size
             })
    # Decode
    utils.print_out("# Start decoding")
    if hparams.inference_indices:
      _decode_inference_indices(
          loaded_infer_model,
          sess,
          output_infer=output_infer,
          output_infer_summary_prefix=output_infer,
          inference_indices=hparams.inference_indices,
          tgt_eos=hparams.eos,
          bpe_delimiter=hparams.bpe_delimiter)
    else:
      nmt_utils.decode_and_evaluate(
          "infer",
          loaded_infer_model,
          sess,
          output_infer,
          ref_file=None,
          metrics=hparams.metrics,
          bpe_delimiter=hparams.bpe_delimiter,
          beam_width=hparams.beam_width,
          tgt_eos=hparams.eos)
def single_worker_inference(infer_model, ckpt, inference_input_file,
                            inference_output_file, hparams):
  output_infer = inference_output_file
  infer_data = load_data(inference_input_file, hparams)

  with tf.Session(graph=infer_model.graph,
                  config=utils.get_config_proto()) as sess:
    loaded_infer_model = model_util.load_model(
        infer_model.model, ckpt, sess, "infer")
    sess.run(infer_model.iterator.initializer,
             feed_dict={
                 infer_model.src_placeholder: infer_data,
                 infer_model.batch_size_placeholder: hparams.infer_batch_size
             })
    utils.print_out("# Start decoding")
    if hparams.inference_indices:
      _decode_inference_indices(
          loaded_infer_model,
          sess,
          output_infer=output_infer,
          output_infer_summary_prefix=output_infer,
          inference_indices=hparams.inference_indices,
          tgt_eos=hparams.eos,
          subword_option=hparams.subword_option)
    else:
      nmt_utils.decode_and_evaluate(
          "infer",
          loaded_infer_model,
          sess,
          output_infer,
          ref_file=None,
          metrics=hparams.metrics,
          subword_option=hparams.subword_option,
          beam_width=hparams.beam_width,
          tgt_eos=hparams.eos,
          num_translations_per_input=hparams.num_translations_per_input)
def run_sample_decode_pungan_prepare(hparams, scope=None, target_session=""):
  log_device_placement = hparams.log_device_placement
  out_dir = hparams.out_dir

  if not hparams.attention:  # choose this model
    model_creator = nmt_model.Model
  else:  # Attention
    if (hparams.encoder_type == "gnmt" or
        hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
      model_creator = gnmt_model.GNMTModel
    elif hparams.attention_architecture == "standard":
      model_creator = attention_model.AttentionModel
    else:
      raise ValueError("Unknown attention architecture %s" %
                       hparams.attention_architecture)

  infer_model = model_helper.create_infer_model(model_creator, hparams, scope)

  def dealt(input, output):
    # Reverse the token order of every line in `input`, writing to `output`.
    with open(input) as f:
      with open(output, 'w') as fw:
        for line in f:
          l = line.strip().split()
          l.reverse()
          sent = ' '.join(l)
          fw.write(sent + '\n')

  wsd_src_file = "%s" % (hparams.sample_prefix)
  wsd_src_file_new = wsd_src_file + '.new'
  dealt(wsd_src_file, wsd_src_file_new)
  wsd_src_data = inference.load_data(wsd_src_file_new)
  model_dir = hparams.out_dir

  # TensorFlow model
  config_proto = utils.get_config_proto(
      log_device_placement=log_device_placement,
      num_intra_threads=hparams.num_intra_threads,
      num_inter_threads=hparams.num_inter_threads)
  infer_sess = tf.Session(target=target_session, config=config_proto,
                          graph=infer_model.graph)
  print('len wsd_src_data', len(wsd_src_data))

  # Decode in fixed-size blocks of 32 sentences (explicit integer division).
  eval_result = []
  for i in range(len(wsd_src_data) // 32):
    eval_result += run_sample_decode_pungan(
        infer_model, infer_sess, model_dir, hparams,
        wsd_src_data[i * 32:(i + 1) * 32])
  print('eval_result')
  print(eval_result)
  print(len(eval_result))

  backward_step1_in = []
  with open(PUNGAN_ROOT_PATH +
            '/Pun_Generation/data/1backward/backward_step1.in') as f:
    for line in f:
      backward_step1_in.append(line.strip())

  def wsd_input_format(wsd_src_data, eval_result):
    '''Example test_data[0]:
    {'target_word': u'art#n', 'target_sense': None,
     'id': 'senseval2.d000.s000.t000',
     'context': ['the', '<target>', 'of', 'change_ringing', 'be', 'peculiar',
                 'to', 'the', 'english', ',', 'and', ',', 'like', 'most',
                 'english', 'peculiarity', ',', 'unintelligible', 'to', 'the',
                 'rest', 'of', 'the', 'world', '.'],
     'poss': ['DET', 'NOUN', 'ADP', 'NOUN', 'VERB', 'ADJ', 'PRT', 'DET',
              'NOUN', '.', 'CONJ', '.', 'ADP', 'ADJ', 'ADJ', 'NOUN', '.',
              'ADJ', 'PRT', 'DET', 'NOUN', 'ADP', 'DET', 'NOUN', '.']}
    '''
    wsd_input = []
    senses_input = []
    for i in range(len(eval_result)):
      block = i // 32
      src_word1, src_word2 = (backward_step1_in[2 * block],
                              backward_step1_in[2 * block + 1])
      tgt_sent = wsd_src_data[i].decode().encode('utf-8') + ' ' + eval_result[i]
      tgt_word = src_word1
      synset = wn.lemma_from_key(tgt_word).synset()
      s = synset.name()
      target_word = '#'.join(s.split('.')[:2])
      context = tgt_sent.split(' ')
      for j in range(len(context)):
        if context[j] == tgt_word:
          context[j] = '<target>'
      poss_list = ['.' for _ in range(len(context))]
      tmp_dict = {
          'target_word': target_word,
          'target_sense': None,
          'id': None,
          'context': context,
          'poss': poss_list
      }
      wsd_input.append(tmp_dict)
      senses_input.append((src_word1, src_word2))
    return wsd_input, senses_input

  wsd_input, senses_input = wsd_input_format(wsd_src_data, eval_result)
  print('wsd_input', wsd_input)
  print("len of wsd_input", len(wsd_input))
  return wsd_input, senses_input, wsd_src_data, eval_result
def train(hparams, scope=None, target_session=""): """Train a translation model.""" log_device_placement = hparams.log_device_placement out_dir = hparams.out_dir num_train_steps = hparams.num_train_steps steps_per_stats = hparams.steps_per_stats steps_per_external_eval = hparams.steps_per_external_eval steps_per_eval = 10 * steps_per_stats avg_ckpts = hparams.avg_ckpts if not steps_per_external_eval: steps_per_external_eval = 5 * steps_per_eval # Create model model_creator = get_model_creator(hparams) train_model = model_helper.create_train_model(model_creator, hparams, scope) eval_model = model_helper.create_eval_model(model_creator, hparams, scope) infer_model = model_helper.create_infer_model(model_creator, hparams, scope) # Preload data for sample decoding. dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src) dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt) sample_src_data = inference.load_data(dev_src_file) sample_tgt_data = inference.load_data(dev_tgt_file) summary_name = "train_log" model_dir = hparams.out_dir # Log and output files log_file = os.path.join(out_dir, "log_%d" % time.time()) log_f = tf.gfile.GFile(log_file, mode="a") utils.print_out("# log_file=%s" % log_file, log_f) # TensorFlow model config_proto = utils.get_config_proto( log_device_placement=log_device_placement, num_intra_threads=hparams.num_intra_threads, num_inter_threads=hparams.num_inter_threads) train_sess = tf.Session(target=target_session, config=config_proto, graph=train_model.graph) eval_sess = tf.Session(target=target_session, config=config_proto, graph=eval_model.graph) infer_sess = tf.Session(target=target_session, config=config_proto, graph=infer_model.graph) with train_model.graph.as_default(): loaded_train_model, global_step = model_helper.create_or_load_model( train_model.model, model_dir, train_sess, "train") # Summary writer summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name), train_model.graph) # First evaluation run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams, summary_writer, sample_src_data, sample_tgt_data, avg_ckpts) last_stats_step = global_step last_eval_step = global_step last_external_eval_step = global_step # This is the training loop. stats, info, start_train_time = before_train(loaded_train_model, train_model, train_sess, global_step, hparams, log_f) while global_step < num_train_steps: ### Run a step ### start_time = time.time() try: step_result = loaded_train_model.train(train_sess) hparams.epoch_step += 1 except tf.errors.OutOfRangeError: # Finished going through the training dataset. Go to next epoch. hparams.epoch_step = 0 utils.print_out( "# Finished an epoch, step %d. Perform external evaluation" % global_step) run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer, sample_src_data, sample_tgt_data) run_external_eval(infer_model, infer_sess, model_dir, hparams, summary_writer) if avg_ckpts: run_avg_external_eval(infer_model, infer_sess, model_dir, hparams, summary_writer, global_step) train_sess.run(train_model.iterator.initializer, feed_dict={train_model.skip_count_placeholder: 0}) continue # Process step_result, accumulate stats, and write summary global_step, info["learning_rate"], step_summary = update_stats( stats, start_time, step_result) summary_writer.add_summary(step_summary, global_step) # Once in a while, we print statistics. 
if global_step - last_stats_step >= steps_per_stats: last_stats_step = global_step is_overflow = process_stats(stats, info, global_step, steps_per_stats, log_f) print_step_info(" ", global_step, info, get_best_results(hparams), log_f) if is_overflow: break # Reset statistics stats = init_stats() if global_step - last_eval_step >= steps_per_eval: last_eval_step = global_step utils.print_out("# Save eval, global step %d" % global_step) add_info_summaries(summary_writer, global_step, info) # Save checkpoint loaded_train_model.saver.save(train_sess, os.path.join(out_dir, "translate.ckpt"), global_step=global_step) # Evaluate on dev/test run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer, sample_src_data, sample_tgt_data) run_internal_eval(eval_model, eval_sess, model_dir, hparams, summary_writer) if global_step - last_external_eval_step >= steps_per_external_eval: last_external_eval_step = global_step # Save checkpoint loaded_train_model.saver.save(train_sess, os.path.join(out_dir, "translate.ckpt"), global_step=global_step) run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer, sample_src_data, sample_tgt_data) run_external_eval(infer_model, infer_sess, model_dir, hparams, summary_writer) if avg_ckpts: run_avg_external_eval(infer_model, infer_sess, model_dir, hparams, summary_writer, global_step) # Done training loaded_train_model.saver.save(train_sess, os.path.join(out_dir, "translate.ckpt"), global_step=global_step) (result_summary, _, final_eval_metrics) = (run_full_eval( model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams, summary_writer, sample_src_data, sample_tgt_data, avg_ckpts)) print_step_info("# Final, ", global_step, info, result_summary, log_f) utils.print_time("# Done training!", start_train_time) summary_writer.close() utils.print_out("# Start evaluating saved best models.") for metric in hparams.metrics: best_model_dir = getattr(hparams, "best_" + metric + "_dir") summary_writer = tf.summary.FileWriter( os.path.join(best_model_dir, summary_name), infer_model.graph) result_summary, best_global_step, _ = run_full_eval( best_model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams, summary_writer, sample_src_data, sample_tgt_data) print_step_info("# Best %s, " % metric, best_global_step, info, result_summary, log_f) summary_writer.close() if avg_ckpts: best_model_dir = getattr(hparams, "avg_best_" + metric + "_dir") summary_writer = tf.summary.FileWriter( os.path.join(best_model_dir, summary_name), infer_model.graph) result_summary, best_global_step, _ = run_full_eval( best_model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams, summary_writer, sample_src_data, sample_tgt_data) print_step_info("# Averaged Best %s, " % metric, best_global_step, info, result_summary, log_f) summary_writer.close() return final_eval_metrics, global_step
def train(hparams, identity, scope=None, target_session=""):
  """Main loop to train the dialogue model. identity is used."""
  out_dir = hparams.out_dir
  steps_per_stats = hparams.steps_per_stats
  steps_per_internal_eval = 3 * steps_per_stats

  model_creator = diag_model.Model
  train_model = model_helper.create_train_model(model_creator, hparams, scope)
  model_dir = hparams.out_dir

  # Log and output files
  log_file = os.path.join(out_dir, identity + "log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)
  avg_step_time = 0.0

  # Load TensorFlow session and model.
  config_proto = utils.get_config_proto(
      log_device_placement=hparams.log_device_placement,
      allow_soft_placement=True)
  train_sess = tf.Session(
      target=target_session, config=config_proto, graph=train_model.graph)

  train_handle = train_sess.run(train_model.train_iterator.string_handle())
  with train_model.graph.as_default():
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, model_dir, train_sess, "train")

  # Initialize summary writer.
  summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, "train_log"), train_model.graph)

  last_stats_step = global_step
  last_eval_step = global_step

  # Initialize training stats.
  step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
  checkpoint_total_count = 0.0
  speed, train_ppl = 0.0, 0.0
  start_train_time = time.time()
  utils.print_out(
      "# Start step %d, lr %g, %s" %
      (global_step,
       loaded_train_model.learning_rate.eval(session=train_sess),
       time.ctime()), log_f)

  # Initialize iterators.
  skip_count = hparams.batch_size * hparams.epoch_step
  utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
  train_sess.run(
      train_model.train_iterator.initializer,
      feed_dict={train_model.skip_count_placeholder: skip_count})

  # Main training loop.
  while global_step < hparams.num_train_steps:
    start_time = time.time()
    try:
      # Run a step.
      step_result = loaded_train_model.train(train_sess, train_handle)
      (_, step_loss, all_summaries, step_predict_count, step_summary,
       global_step, step_word_count, batch_size, _, _, words1, words2,
       mask1, mask2) = step_result
      hparams.epoch_step += 1
    except tf.errors.OutOfRangeError:
      # Finished an epoch.
      hparams.epoch_step = 0
      utils.print_out("# Finished an epoch, step %d." % global_step)
      train_sess.run(
          train_model.train_iterator.initializer,
          feed_dict={train_model.skip_count_placeholder: 0})
      continue

    # Write step summary.
    summary_writer.add_summary(step_summary, global_step)
    for key in all_summaries:
      utils.add_summary(summary_writer, global_step, key, all_summaries[key])

    # Update statistics.
    step_time += (time.time() - start_time)
    checkpoint_loss += (step_loss * batch_size)
    checkpoint_predict_count += step_predict_count
    checkpoint_total_count += float(step_word_count)

    if global_step - last_stats_step >= steps_per_stats:
      # Print statistics for the previous epoch and save the model.
      last_stats_step = global_step
      avg_step_time = step_time / steps_per_stats
      utils.add_summary(summary_writer, global_step, "step_time",
                        avg_step_time)
      train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
      speed = checkpoint_total_count / (1000 * step_time)
      if math.isnan(train_ppl):
        break

      # Reset timer and loss.
      step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
      checkpoint_total_count = 0.0

      # Save the model.
      loaded_train_model.saver.save(
          train_sess,
          os.path.join(out_dir, "dialogue.ckpt"),
          global_step=global_step)

      # Print the dialogue if in debug mode.
      if hparams.debug:
        utils.print_current_dialogue(words1, words2, mask1, mask2)

    # Write out internal evaluation.
    if global_step - last_eval_step >= steps_per_internal_eval:
      last_eval_step = global_step
      utils.print_out("# Internal Evaluation. global step %d" % global_step)
      utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

  # Finished training.
  loaded_train_model.saver.save(
      train_sess,
      os.path.join(out_dir, "dialogue.ckpt"),
      global_step=global_step)
  result_summary = ""
  utils.print_out(
      "# Final, step %d lr %g step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
      (global_step,
       loaded_train_model.learning_rate.eval(session=train_sess),
       avg_step_time, speed, train_ppl, result_summary, time.ctime()), log_f)
  utils.print_time("# Done training!", start_train_time)
  utils.print_out("# Start evaluating saved best models.")
  summary_writer.close()
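# utils.safe_exp is not shown in this section; below is a minimal sketch of
# what it plausibly does (guard exp() against overflow when the loss spikes),
# so train_ppl becomes inf rather than raising. This is an assumption for
# illustration, not the repository's verbatim code.
import math

def safe_exp(value):
  """exp(value), returning inf instead of raising OverflowError."""
  try:
    ans = math.exp(value)
  except OverflowError:
    ans = float("inf")
  return ans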
def train(hps, scope=None, target_session=""):
  """Train a summarization model."""
  log_device_placement = hps.log_device_placement
  out_dir = hps.out_dir
  num_train_steps = hps.num_train_steps
  steps_per_stats = hps.steps_per_stats
  steps_per_external_eval = hps.steps_per_external_eval
  steps_per_eval = 100 * steps_per_stats
  if not steps_per_external_eval:
    steps_per_external_eval = 5 * steps_per_eval

  if hps.attention_architecture == "baseline":
    model_creator = AttentionModel
  else:
    model_creator = AttentionHistoryModel

  train_model = model_helper.create_train_model(model_creator, hps, scope)
  eval_model = model_helper.create_eval_model(model_creator, hps, scope)
  infer_model = model_helper.create_infer_model(model_creator, hps, scope)

  # Preload data for sample decoding.
  article_filenames = []
  abstract_filenames = []
  art_dir = hps.data_dir + '/article'
  abs_dir = hps.data_dir + '/abstract'
  for file in os.listdir(art_dir):
    if file.startswith(hps.dev_prefix):
      article_filenames.append(art_dir + "/" + file)
  for file in os.listdir(abs_dir):
    if file.startswith(hps.dev_prefix):
      abstract_filenames.append(abs_dir + "/" + file)

  # Random sampling during training: pick one dev article/abstract pair
  # at random for sample decoding.
  decode_id = random.randint(0, len(article_filenames) - 1)
  single_article_file = article_filenames[decode_id]
  single_abstract_file = abstract_filenames[decode_id]
  dev_src_file = single_article_file
  dev_tgt_file = single_abstract_file
  sample_src_data = inference_base_model.load_data(dev_src_file)
  sample_tgt_data = inference_base_model.load_data(dev_tgt_file)

  summary_name = "train_log"
  model_dir = hps.out_dir

  # Log and output files
  log_file = os.path.join(out_dir, "log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)
  avg_step_time = 0.0

  # TensorFlow model
  config_proto = utils.get_config_proto(
      log_device_placement=log_device_placement)
  train_sess = tf.Session(
      target=target_session, config=config_proto, graph=train_model.graph)
  eval_sess = tf.Session(
      target=target_session, config=config_proto, graph=eval_model.graph)
  infer_sess = tf.Session(
      target=target_session, config=config_proto, graph=infer_model.graph)

  with train_model.graph.as_default():
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, model_dir, train_sess, "train")

  # Summary writer
  summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, summary_name), train_model.graph)

  # First evaluation (disabled):
  # run_full_eval(
  #     model_dir, infer_model, infer_sess,
  #     eval_model, eval_sess, hps,
  #     summary_writer, sample_src_data, sample_tgt_data)

  last_stats_step = global_step
  last_eval_step = global_step
  last_external_eval_step = global_step

  # This is the training loop.
  stats = init_stats()
  speed, train_ppl = 0.0, 0.0
  start_train_time = time.time()
  utils.print_out(
      "# Start step %d, lr %g, %s" %
      (global_step,
       loaded_train_model.learning_rate.eval(session=train_sess),
       time.ctime()), log_f)

  # Initialize all of the iterators.
  skip_count = hps.batch_size * hps.epoch_step
  utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
  train_sess.run(
      train_model.iterator.initializer,
      feed_dict={train_model.skip_count_placeholder: skip_count})
  epoch_step = 0

  while global_step < num_train_steps:
    ### Run a step ###
    start_time = time.time()
    try:
      step_result = loaded_train_model.train(train_sess)
      epoch_step += 1
    except tf.errors.OutOfRangeError:
      # Finished going through the training dataset. Go to next epoch.
      epoch_step = 0
      utils.print_out(
          "# Finished an epoch, step %d. Perform external evaluation" %
          global_step)
      run_sample_decode(infer_model, infer_sess, model_dir, hps,
                        summary_writer, sample_src_data, sample_tgt_data)
      dev_scores, test_scores, _ = run_external_eval(
          infer_model, infer_sess, model_dir, hps, summary_writer)
      train_sess.run(
          train_model.iterator.initializer,
          feed_dict={train_model.skip_count_placeholder: 0})
      continue

    # Write step summary and accumulate statistics.
    global_step = update_stats(stats, summary_writer, start_time, step_result)

    # Once in a while, we print statistics.
    if global_step - last_stats_step >= steps_per_stats:
      last_stats_step = global_step
      is_overflow = check_stats(stats, global_step, steps_per_stats, hps,
                                log_f)
      if is_overflow:
        break
      # Reset statistics
      stats = init_stats()

    if global_step - last_eval_step >= steps_per_eval:
      last_eval_step = global_step
      utils.print_out("# Save eval, global step %d" % global_step)
      utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

      # Save checkpoint
      loaded_train_model.saver.save(
          train_sess,
          os.path.join(out_dir, "summarized.ckpt"),
          global_step=global_step)

      # Evaluate on dev/test
      run_sample_decode(infer_model, infer_sess, model_dir, hps,
                        summary_writer, sample_src_data, sample_tgt_data)
      dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess, model_dir,
                                            hps, summary_writer)

    if global_step - last_external_eval_step >= steps_per_external_eval:
      last_external_eval_step = global_step

      # Save checkpoint
      loaded_train_model.saver.save(
          train_sess,
          os.path.join(out_dir, "summarized.ckpt"),
          global_step=global_step)
      run_sample_decode(infer_model, infer_sess, model_dir, hps,
                        summary_writer, sample_src_data, sample_tgt_data)
      dev_scores, test_scores, _ = run_external_eval(
          infer_model, infer_sess, model_dir, hps, summary_writer)

  # Done training
  loaded_train_model.saver.save(
      train_sess,
      os.path.join(out_dir, "summarized.ckpt"),
      global_step=global_step)

  result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = (
      run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
                    hps, summary_writer, sample_src_data, sample_tgt_data))
  utils.print_out(
      "# Final, step %d lr %g step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
      (global_step,
       loaded_train_model.learning_rate.eval(session=train_sess),
       avg_step_time, speed, train_ppl, result_summary, time.ctime()), log_f)
  utils.print_time("# Done training!", start_train_time)
  summary_writer.close()

  utils.print_out("# Start evaluating saved best models.")
  for metric in hps.metrics:
    best_model_dir = getattr(hps, "best_" + metric + "_dir")
    summary_writer = tf.summary.FileWriter(
        os.path.join(best_model_dir, summary_name), infer_model.graph)
    result_summary, best_global_step, _, _, _, _ = run_full_eval(
        best_model_dir, infer_model, infer_sess, eval_model, eval_sess, hps,
        summary_writer, sample_src_data, sample_tgt_data)
    utils.print_out(
        "# Best %s, step %d step-time %.2f wps %.2fK, %s, %s" %
        (metric, best_global_step, avg_step_time, speed, result_summary,
         time.ctime()), log_f)
    summary_writer.close()

  return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
def multi_worker_inference(infer_model, ckpt, inference_input_file,
                           inference_output_file, hparams, num_workers,
                           job_id):
  assert num_workers > 1
  final_output_infer = inference_output_file
  output_infer = "%s_%d" % (inference_output_file, job_id)
  output_infer_done = "%s_done_%d" % (inference_output_file, job_id)

  infer_data = load_data(inference_input_file, hparams)

  # Split the data among workers: ceiling division so every item is covered.
  total_load = len(infer_data)
  load_per_worker = int((total_load - 1) / num_workers) + 1
  start_position = job_id * load_per_worker
  end_position = min(start_position + load_per_worker, total_load)
  infer_data = infer_data[start_position:end_position]

  with tf.Session(graph=infer_model.graph,
                  config=utils.get_config_proto()) as sess:
    loaded_infer_model = model_util.load_model(
        infer_model.model, ckpt, sess, "infer")
    sess.run(infer_model.iterator.initializer,
             feed_dict={
                 infer_model.src_placeholder: infer_data,
                 infer_model.batch_size_placeholder: hparams.infer_batch_size
             })
    utils.print_out("# Start decoding")
    nmt_utils.decode_and_evaluate(
        "infer",
        loaded_infer_model,
        sess,
        output_infer,
        ref_file=None,
        metrics=hparams.metrics,
        subword_option=hparams.subword_option,
        beam_width=hparams.beam_width,
        tgt_eos=hparams.eos,
        num_translations_per_input=hparams.num_translations_per_input)

    # Rename the output file to signal that this worker is done.
    tf.gfile.Rename(output_infer, output_infer_done, overwrite=True)

    # Job 0 merges all workers' outputs.
    if job_id != 0:
      return
    with codecs.getwriter("utf-8")(
        tf.gfile.GFile(final_output_infer, mode="wb")) as final_f:
      for worker_id in range(num_workers):
        worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id)
        while not tf.gfile.Exists(worker_infer_done):
          utils.print_out("  waiting for job %d to complete." % worker_id)
          time.sleep(10)
        with codecs.getreader("utf-8")(
            tf.gfile.GFile(worker_infer_done, mode="rb")) as f:
          for translation in f:
            final_f.write("%s" % translation)
    for worker_id in range(num_workers):
      worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id)
      tf.gfile.Remove(worker_infer_done)
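# A small illustration (assumed, not from the source) of the sharding
# arithmetic above: ceil-divide the workload so all items are covered and the
# last worker takes whatever remainder is left.
total_load = 10
num_workers = 3
load_per_worker = int((total_load - 1) / num_workers) + 1  # ceil(10 / 3) == 4
for job_id in range(num_workers):
  start = job_id * load_per_worker
  end = min(start + load_per_worker, total_load)
  print(job_id, list(range(start, end)))  # 0:[0..3], 1:[4..7], 2:[8, 9]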
def run_main(flags, default_hparams, train_fn, inference_fn,
             target_session=""):
  """Run main."""
  # Job
  jobid = flags.jobid
  num_workers = flags.num_workers
  utils.print_out("# Job id %d" % jobid)

  # Random
  random_seed = flags.random_seed
  if random_seed is not None and random_seed > 0:
    utils.print_out("# Set random seed to %d" % random_seed)
    random.seed(random_seed + jobid)
    np.random.seed(random_seed + jobid)

  # Model output directory
  out_dir = flags.out_dir
  if out_dir and not tf.gfile.Exists(out_dir):
    utils.print_out("# Creating output directory %s ..." % out_dir)
    tf.gfile.MakeDirs(out_dir)

  # Load hparams.
  loaded_hparams = False
  if flags.ckpt:  # Try to load hparams from the same directory as ckpt.
    ckpt_dir = os.path.dirname(flags.ckpt)
    ckpt_hparams_file = os.path.join(ckpt_dir, "hparams")
    if tf.gfile.Exists(ckpt_hparams_file) or flags.hparams_path:
      hparams = create_or_load_hparams(
          ckpt_dir, default_hparams, flags.hparams_path, save_hparams=False)
      loaded_hparams = True
  if not loaded_hparams:  # Try to load from out_dir.
    assert out_dir
    hparams = create_or_load_hparams(
        out_dir, default_hparams, flags.hparams_path,
        save_hparams=(jobid == 0))

  # GPU device
  config_proto = utils.get_config_proto(
      allow_soft_placement=True,
      num_intra_threads=hparams.num_intra_threads,
      num_inter_threads=hparams.num_inter_threads)
  utils.print_out("# Devices visible to TensorFlow: %s" %
                  repr(tf.Session(config=config_proto).list_devices()))

  ## Train / Decode
  if flags.inference_input_file:
    # Inference output directory
    trans_file = flags.inference_output_file
    assert trans_file
    trans_dir = os.path.dirname(trans_file)
    if not tf.gfile.Exists(trans_dir):
      tf.gfile.MakeDirs(trans_dir)

    # Inference indices
    hparams.inference_indices = None
    if flags.inference_list:
      hparams.inference_indices = [
          int(token) for token in flags.inference_list.split(",")
      ]

    # Inference
    ckpt = flags.ckpt
    if not ckpt:
      ckpt = tf.train.latest_checkpoint(out_dir)
    inference_fn(flags.run, flags.iterations, ckpt,
                 flags.inference_input_file, trans_file, hparams, num_workers,
                 jobid)

    # Evaluation
    if flags.run == 'accuracy':
      ref_file = flags.inference_ref_file
      if ref_file and tf.gfile.Exists(trans_file):
        for metric in hparams.metrics:
          score = evaluation_utils.evaluate(ref_file, trans_file, metric,
                                            hparams.subword_option)
          utils.print_out("  %s: %.1f" % (metric, score))
  else:
    # Train
    train_fn(hparams, target_session=target_session)
def train(hparams, scope=None, target_session="", compute_ppl=0): """Train a translation model.""" log_device_placement = hparams.log_device_placement out_dir = hparams.out_dir num_train_steps = hparams.num_train_steps steps_per_stats = hparams.steps_per_stats steps_per_external_eval = hparams.steps_per_external_eval steps_per_eval = 10 * steps_per_stats avg_ckpts = hparams.avg_ckpts if not steps_per_external_eval: steps_per_external_eval = 5 * steps_per_eval if not hparams.attention: # choose this model model_creator = nmt_model.Model else: # Attention if (hparams.encoder_type == "gnmt" or hparams.attention_architecture in ["gnmt", "gnmt_v2"]): model_creator = gnmt_model.GNMTModel elif hparams.attention_architecture == "standard": model_creator = attention_model.AttentionModel else: raise ValueError("Unknown attention architecture %s" % hparams.attention_architecture) train_model = model_helper.create_train_model(model_creator, hparams, scope) eval_model = model_helper.create_eval_model(model_creator, hparams, scope) infer_model = model_helper.create_infer_model(model_creator, hparams, scope) # Preload data for sample decoding. dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src) dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt) sample_src_data = inference.load_data(dev_src_file) sample_tgt_data = inference.load_data(dev_tgt_file) wsd_src_file = "%s" % (hparams.sample_prefix) wsd_src_data = inference.load_data(wsd_src_file) summary_name = "train_log" model_dir = hparams.out_dir # Log and output files log_file = os.path.join(out_dir, "log_%d" % time.time()) log_f = tf.gfile.GFile(log_file, mode="a") utils.print_out("# log_file=%s" % log_file, log_f) # TensorFlow model config_proto = utils.get_config_proto( log_device_placement=log_device_placement, num_intra_threads=hparams.num_intra_threads, num_inter_threads=hparams.num_inter_threads) train_sess = tf.Session(target=target_session, config=config_proto, graph=train_model.graph) eval_sess = tf.Session(target=target_session, config=config_proto, graph=eval_model.graph) infer_sess = tf.Session(target=target_session, config=config_proto, graph=infer_model.graph) with train_model.graph.as_default(): loaded_train_model, global_step = model_helper.create_or_load_model( train_model.model, model_dir, train_sess, "train") # Summary writer summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name), train_model.graph) # First evaluation ''' run_full_eval( model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams, summary_writer, sample_src_data, sample_tgt_data, avg_ckpts) ''' last_stats_step = global_step last_eval_step = global_step last_external_eval_step = global_step # This is the training loop. stats, info, start_train_time = before_train(loaded_train_model, train_model, train_sess, global_step, hparams, log_f) end_step = global_step + 100 while global_step < end_step: # num_train_steps ### Run a step ### start_time = time.time() try: # then forward inference result to WSD, get reward step_result = loaded_train_model.train(train_sess) # forward reward to placeholder of loaded_train_model, and write a new train function where loss = loss*reward hparams.epoch_step += 1 except tf.errors.OutOfRangeError: # Finished going through the training dataset. Go to next epoch. hparams.epoch_step = 0 utils.print_out( "# Finished an epoch, step %d. 
Perform external evaluation" % global_step) # run_sample_decode(infer_model, infer_sess, model_dir, hparams, # summary_writer, sample_src_data, sample_tgt_data) # only for pretrain # run_external_eval(infer_model, infer_sess, model_dir, hparams, # summary_writer) if avg_ckpts: run_avg_external_eval(infer_model, infer_sess, model_dir, hparams, summary_writer, global_step) train_sess.run(train_model.iterator.initializer, feed_dict={train_model.skip_count_placeholder: 0}) continue # Process step_result, accumulate stats, and write summary global_step, info["learning_rate"], step_summary = update_stats( stats, start_time, step_result, hparams) summary_writer.add_summary(step_summary, global_step) if compute_ppl: run_internal_eval(eval_model, eval_sess, model_dir, hparams, summary_writer) # Once in a while, we print statistics. if global_step - last_stats_step >= steps_per_stats: last_stats_step = global_step is_overflow = process_stats(stats, info, global_step, steps_per_stats, log_f) print_step_info(" ", global_step, info, _get_best_results(hparams), log_f) if is_overflow: break # Reset statistics stats = init_stats() if global_step - last_eval_step >= steps_per_eval: last_eval_step = global_step utils.print_out("# Save eval, global step %d" % global_step) utils.add_summary(summary_writer, global_step, "train_ppl", info["train_ppl"]) # Save checkpoint loaded_train_model.saver.save(train_sess, os.path.join(out_dir, "translate.ckpt"), global_step=global_step) # Evaluate on dev/test run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer, sample_src_data, sample_tgt_data) run_internal_eval(eval_model, eval_sess, model_dir, hparams, summary_writer) if global_step - last_external_eval_step >= steps_per_external_eval: last_external_eval_step = global_step # Save checkpoint loaded_train_model.saver.save(train_sess, os.path.join(out_dir, "translate.ckpt"), global_step=global_step) run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer, sample_src_data, sample_tgt_data) run_external_eval(infer_model, infer_sess, model_dir, hparams, summary_writer) if avg_ckpts: run_avg_external_eval(infer_model, infer_sess, model_dir, hparams, summary_writer, global_step) # Done training loaded_train_model.saver.save(train_sess, os.path.join(out_dir, "translate.ckpt"), global_step=global_step) '''
def multi_worker_inference(infer_model, ckpt, inference_input_file,
                           inference_output_file, hparams, num_workers,
                           jobid):
  """Inference using multiple workers."""
  assert num_workers > 1
  final_output_infer = inference_output_file
  output_infer = "%s_%d" % (inference_output_file, jobid)
  output_infer_done = "%s_done_%d" % (inference_output_file, jobid)

  # Read data
  infer_data = load_data(inference_input_file, hparams)

  # Split data to multiple workers
  total_load = len(infer_data)
  load_per_worker = int((total_load - 1) / num_workers) + 1
  start_position = jobid * load_per_worker
  end_position = min(start_position + load_per_worker, total_load)
  infer_data = infer_data[start_position:end_position]

  with tf.Session(graph=infer_model.graph,
                  config=utils.get_config_proto()) as sess:
    loaded_infer_model = model_helper.load_model(
        infer_model.model, ckpt, sess, "infer")
    sess.run(infer_model.iterator.initializer,
             {infer_model.src_placeholder: infer_data})

    # Decode
    utils.print_out("# Start decoding")
    nmt_utils.decode_and_evaluate(
        "infer",
        loaded_infer_model,
        sess,
        output_infer,
        ref_file=None,
        metrics=hparams.metrics,
        bpe_delimiter=hparams.bpe_delimiter,
        beam_width=hparams.beam_width,
        tgt_eos=hparams.eos)

    # Change file name to indicate the file writing is completed.
    tf.gfile.Rename(output_infer, output_infer_done, overwrite=True)

    # Job 0 is responsible for the clean up.
    if jobid != 0:
      return

    # Now write all translations
    with codecs.getwriter("utf-8")(
        tf.gfile.GFile(final_output_infer, mode="wb")) as final_f:
      for worker_id in range(num_workers):
        worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id)
        while not tf.gfile.Exists(worker_infer_done):
          utils.print_out("  waiting for job %d to complete." % worker_id)
          time.sleep(10)
        with codecs.getreader("utf-8")(
            tf.gfile.GFile(worker_infer_done, mode="rb")) as f:
          for translation in f:
            final_f.write("%s" % translation)
        tf.gfile.Remove(worker_infer_done)
def train(hparams):
  """Train a seq2seq model."""
  log_device_placement = hparams.log_device_placement
  out_dir = hparams.out_dir
  num_train_steps = hparams.num_train_steps
  steps_per_stats = hparams.steps_per_stats

  model_creator = model.Model
  train_model = model_helper.create_train_model(model_creator, hparams)

  summary_name = "train_log"
  model_dir = hparams.out_dir

  # Log and output files
  log_file = os.path.join(out_dir, "log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)

  # TensorFlow model
  config_proto = utils.get_config_proto(
      log_device_placement=log_device_placement,
      num_intra_threads=hparams.num_intra_threads,
      num_inter_threads=hparams.num_inter_threads)
  train_sess = tf.Session(config=config_proto, graph=train_model.graph)

  with train_model.graph.as_default():
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, model_dir, train_sess, "train")

  # Summary writer
  summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, summary_name), train_model.graph)

  last_stats_step = global_step

  # This is the training loop.
  stats, info, start_train_time = before_train(
      loaded_train_model, train_model, train_sess, global_step, hparams, log_f)
  while global_step < num_train_steps:
    ### Run a step ###
    start_time = time.time()
    try:
      step_result = loaded_train_model.train(train_sess)
      hparams.epoch_step += 1
    except tf.errors.OutOfRangeError:
      # Finished going through the training dataset. Go to next epoch.
      hparams.epoch_step = 0
      utils.print_out("# Finished an epoch, step %d." % global_step)
      train_sess.run(
          train_model.iterator.initializer,
          feed_dict={train_model.skip_count_placeholder: 0})
      continue

    # Process step_result, accumulate stats, and write summary
    global_step, info["learning_rate"], step_summary = update_stats(
        stats, start_time, step_result)
    summary_writer.add_summary(step_summary, global_step)

    # Once in a while, we print statistics.
    if global_step - last_stats_step >= steps_per_stats:
      last_stats_step = global_step
      is_overflow = process_stats(stats, info, global_step, steps_per_stats,
                                  log_f)
      print_step_info("  ", global_step, info, log_f)
      if is_overflow:
        break
      # Reset statistics
      stats = init_stats()

  # Done training
  loaded_train_model.saver.save(
      train_sess,
      os.path.join(out_dir, "seq2seq.ckpt"),
      global_step=global_step)
  summary_writer.close()
  return global_step
def multi_worker_selfplay(hparams,
                          identity,
                          scope=None,
                          target_session='',
                          is_chief=True,
                          ps_tasks=0,
                          num_workers=1,
                          jobid=0,
                          startup_delay_steps=0):
  """This is the multi-worker selfplay, mostly used for selfplay
  distributed training. identity is used.
  """
  immutable_model_reload_freq = hparams.immutable_model_reload_freq

  # 1. Models and summary writer
  model_creator = diag_model.Model
  extra_args = model_helper.ExtraArgs(
      single_cell_fn=None,
      model_device_fn=tf.train.replica_device_setter(ps_tasks),
      attention_mechanism_fn=None)
  mutable_model = model_helper.create_selfplay_model(
      model_creator,
      is_mutable=True,
      num_workers=num_workers,
      jobid=jobid,
      hparams=hparams,
      scope=scope,
      extra_args=extra_args)
  immutable_hparams = copy.deepcopy(hparams)
  immutable_hparams.num_gpus = 0
  immutable_model = model_helper.create_selfplay_model(
      model_creator,
      is_mutable=False,
      num_workers=num_workers,
      jobid=jobid,
      hparams=immutable_hparams,
      scope=scope)

  if hparams.self_play_immutable_gpu:
    print('using GPU for immutable')
    immutable_sess = tf.Session(
        graph=immutable_model.graph,
        config=tf.ConfigProto(allow_soft_placement=True))
  else:
    print('not using GPU for immutable')
    immutable_sess = tf.Session(
        graph=immutable_model.graph,
        config=tf.ConfigProto(
            allow_soft_placement=True, device_count={'GPU': 0}))
  immutable_model, immutable_sess = load_self_play_model(
      immutable_model, immutable_sess, 'immutable',
      hparams.self_play_pretrain_dir, hparams.out_dir)
  global_step = immutable_model.model.global_step.eval(session=immutable_sess)

  if is_chief:
    ckpt = tf.train.latest_checkpoint(hparams.out_dir)
    if not ckpt:
      print('global_step, saving pretrain model to hparams.out_dir',
            global_step, hparams.out_dir)
      immutable_model.model.saver.save(  # this prevents the Adam error
          immutable_sess,
          os.path.join(hparams.out_dir, 'dialogue.ckpt'),
          global_step=global_step)
      print('save finished')

  if is_chief:
    summary_writer_path = os.path.join(
        hparams.out_dir, identity + task_SP_DISTRIBUTED + '_log')
    summary_writer = tf.summary.FileWriter(summary_writer_path,
                                           mutable_model.graph)
    print('summary writer established at', summary_writer_path)
  else:
    summary_writer = None

  # 2. Supervisor and sessions
  sv = tf.train.Supervisor(
      graph=mutable_model.graph,
      is_chief=is_chief,
      saver=mutable_model.model.saver,
      save_model_secs=0,  # disable automatic checkpoint saving
      summary_op=None,
      logdir=hparams.out_dir,
      checkpoint_basename='dialogue.ckpt')
  mutable_config = utils.get_config_proto(
      log_device_placement=hparams.log_device_placement,
      allow_soft_placement=True)
  mutable_config.device_count['GPU'] = hparams.num_gpus
  mutable_sess = sv.prepare_or_wait_for_session(
      target_session, config=mutable_config)

  # 3. Additional preparations
  global_step = mutable_model.model.global_step.eval(session=mutable_sess)
  while global_step < (jobid * (jobid + 1) * startup_delay_steps / 2):
    time.sleep(1)
    global_step = mutable_model.model.global_step.eval(session=mutable_sess)

  # Save first model.
  if is_chief:
    print('saving the first checkpoint to', hparams.out_dir)
    mutable_model.model.saver.save(
        mutable_sess,
        os.path.join(hparams.out_dir, 'dialogue.ckpt'),
        global_step=global_step)
    last_save_step = global_step

  # Read data
  selfplay_data = dialogue_utils.load_data(hparams.self_play_train_data)
  selfplay_kb = dialogue_utils.load_data(hparams.self_play_train_kb)

  dialogue = SelfplayDialogue(
      mutable_model,
      immutable_model,
      mutable_sess,
      immutable_sess,
      hparams.max_dialogue_turns,
      hparams.train_threadhold,
      hparams.start_of_turn1,
      hparams.start_of_turn2,
      hparams.end_of_dialogue,
      summary_writer=summary_writer,
      dialogue_mode=task_SP_DISTRIBUTED,
      hparams=hparams)

  # 4. Main loop
  last_immutable_model_reload = global_step
  last_save_step = global_step
  batch_size = dialogue.batch_size
  assert batch_size <= len(selfplay_data)

  # This is the start point of the selfplay data;
  # force shuffling at the beginning.
  i = len(selfplay_data)
  train_stats = [0, 0]
  while global_step < hparams.num_self_play_train_steps:
    # a. Reload the immutable model; the mutable one is automatically
    # managed by the supervisor.
    if (immutable_model_reload_freq > 0 and
        global_step - last_immutable_model_reload >
        immutable_model_reload_freq):
      immutable_model, immutable_sess = load_self_play_model(
          immutable_model, immutable_sess, 'immutable',
          hparams.self_play_pretrain_dir, hparams.out_dir)
      last_immutable_model_reload = global_step
    # b. Possibly flip between speakers (or roll out models),
    # based on either a random policy or step counts.
    agent1, agent2, mutable_agent_index = dialogue.flip_agent(
        (mutable_model, mutable_sess, dialogue.mutable_handles),
        (immutable_model, immutable_sess, dialogue.immutable_handles))
    train_stats[mutable_agent_index] += 1

    # Read selfplay data.
    start_time = time.time()
    if i * batch_size + batch_size > len(selfplay_data):  # reached the end
      # Materialize zip() so the shuffle works under Python 3 as well.
      input_data = list(zip(selfplay_data, selfplay_kb))
      random.shuffle(input_data)  # random shuffle input data
      i = 0
      selfplay_data, selfplay_kb = zip(*input_data)
    start_ind, end_ind = i * batch_size, i * batch_size + batch_size
    batch_data, batch_kb = (selfplay_data[start_ind:end_ind],
                            selfplay_kb[start_ind:end_ind])
    train_example, _, _ = dialogue.talk(hparams.max_dialogue_len, batch_data,
                                        batch_kb, agent1, agent2, batch_size,
                                        global_step)
    possible_global_step = dialogue.maybe_train(
        train_example, mutable_agent_index, global_step, force=True)
    if possible_global_step:
      global_step = possible_global_step
    if (is_chief and
        global_step - last_save_step > hparams.self_play_dist_save_freq):
      mutable_model.model.saver.save(
          mutable_sess,
          os.path.join(hparams.out_dir, 'dialogue.ckpt'),
          global_step=global_step)
      last_save_step = global_step
    end_time = time.time()

    if is_chief:
      utils.add_summary(summary_writer, global_step,
                        task_SP_DISTRIBUTED + '_' + 'time',
                        end_time - start_time)
      utils.add_summary(summary_writer, global_step,
                        task_SP_DISTRIBUTED + '_' + 'train_ratio',
                        train_stats[0] * 1.0 / (train_stats[1] + 0.1))
    i += 1

  if is_chief:
    summary_writer.close()
  mutable_sess.close()
  immutable_sess.close()
def eval_fn(hparams, scope=None, target_session=""):
  """Evaluate a translation model."""
  log_device_placement = hparams.log_device_placement
  out_dir = hparams.out_dir
  num_train_steps = hparams.num_train_steps
  steps_per_stats = hparams.steps_per_stats
  steps_per_external_eval = hparams.steps_per_external_eval
  steps_per_eval = 10 * steps_per_stats
  avg_ckpts = hparams.avg_ckpts

  if not steps_per_external_eval:
    steps_per_external_eval = 5 * steps_per_eval

  if not hparams.attention:
    model_creator = nmt_model.Model
  else:  # Attention
    if (hparams.encoder_type == "gnmt" or
        hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
      model_creator = gnmt_model.GNMTModel
    elif hparams.attention_architecture == "standard":
      model_creator = attention_model.AttentionModel
    else:
      raise ValueError("Unknown attention architecture %s" %
                       hparams.attention_architecture)

  eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
  infer_model = model_helper.create_infer_model(model_creator, hparams, scope)

  # Preload data for sample decoding.
  dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
  dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
  sample_src_data = inference.load_data(dev_src_file)
  sample_tgt_data = inference.load_data(dev_tgt_file)

  summary_name = "train_log"
  model_dir = hparams.out_dir

  # Log and output files
  log_file = os.path.join(out_dir, "log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)

  # TensorFlow model
  config_proto = utils.get_config_proto(
      log_device_placement=log_device_placement,
      num_intra_threads=hparams.num_intra_threads,
      num_inter_threads=hparams.num_inter_threads)
  eval_sess = tf.Session(
      target=target_session, config=config_proto, graph=eval_model.graph)
  infer_sess = tf.Session(
      target=target_session, config=config_proto, graph=infer_model.graph)

  # First evaluation: run over every retained checkpoint.
  ckpt_size = len(
      tf.train.get_checkpoint_state(model_dir).all_model_checkpoint_paths)
  for ckpt_index in range(ckpt_size):
    train.run_full_eval(model_dir, infer_model, infer_sess, eval_model,
                        eval_sess, hparams, None, sample_src_data,
                        sample_tgt_data, avg_ckpts, ckpt_index=ckpt_index)
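# A hedged sketch of the checkpoint enumeration eval_fn relies on:
# tf.train.get_checkpoint_state returns a CheckpointState proto (or None when
# no checkpoint file exists), and all_model_checkpoint_paths lists every
# retained checkpoint, oldest first, which is what lets run_full_eval be
# called once per ckpt_index. `list_checkpoints` is an illustrative helper,
# not a function from this codebase.
def list_checkpoints(model_dir):
  """Returns every retained checkpoint path under model_dir (possibly [])."""
  ckpt_state = tf.train.get_checkpoint_state(model_dir)
  if not ckpt_state:
    return []
  return list(ckpt_state.all_model_checkpoint_paths)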
def train(hparams, scope=None, target_session=''):
  """Train the chatbot."""
  # Initialize some local hyperparameters
  log_device_placement = hparams.log_device_placement
  out_dir = hparams.out_dir
  num_train_steps = hparams.num_train_steps
  steps_per_stats = hparams.steps_per_stats
  steps_per_external_eval = hparams.steps_per_external_eval
  steps_per_eval = 10 * steps_per_stats
  if not steps_per_external_eval:
    steps_per_external_eval = 5 * steps_per_eval

  if hparams.architecture == "simple":
    model_creator = SimpleModel
    get_infer_iterator = iterator_utils.get_infer_iterator
    get_iterator = iterator_utils.get_iterator
  elif hparams.architecture == "hier":
    model_creator = HierarchicalModel
    # Parse some of the arguments now

    def curry_get_infer_iterator(dataset, vocab_table, batch_size, src_reverse,
                                 eos, src_max_len):
      return end2end_iterator_utils.get_infer_iterator(
          dataset,
          vocab_table,
          batch_size,
          src_reverse,
          eos,
          src_max_len=src_max_len,
          eou=hparams.eou,
          dialogue_max_len=hparams.dialogue_max_len)

    get_infer_iterator = curry_get_infer_iterator

    def curry_get_iterator(src_dataset,
                           tgt_dataset,
                           vocab_table,
                           batch_size,
                           sos,
                           eos,
                           src_reverse,
                           random_seed,
                           num_buckets,
                           src_max_len=None,
                           tgt_max_len=None,
                           num_threads=4,
                           output_buffer_size=None,
                           skip_count=None):
      return end2end_iterator_utils.get_iterator(
          src_dataset,
          tgt_dataset,
          vocab_table,
          batch_size,
          sos,
          eos,
          eou=hparams.eou,
          src_reverse=src_reverse,
          random_seed=random_seed,
          num_dialogue_buckets=num_buckets,
          src_max_len=src_max_len,
          tgt_max_len=tgt_max_len,
          num_threads=num_threads,
          output_buffer_size=output_buffer_size,
          skip_count=skip_count)

    get_iterator = curry_get_iterator
  else:
    raise ValueError("Unknown architecture", hparams.architecture)

  # Create three models which share parameters through the use of checkpoints
  train_model = create_train_model(model_creator, get_iterator, hparams, scope)
  eval_model = create_eval_model(model_creator, get_iterator, hparams, scope)
  infer_model = inference.create_infer_model(model_creator, get_infer_iterator,
                                             hparams, scope)

  # TODO: adapt for architectures
  # Preload the data to use for sample decoding
  dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
  dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
  sample_src_data = inference.load_data(dev_src_file)
  sample_tgt_data = inference.load_data(dev_tgt_file)

  summary_name = "train_log"
  model_dir = hparams.out_dir

  # Log and output files
  log_file = os.path.join(out_dir, "log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)
  avg_step_time = 0.0

  # Create the configuration for the sessions
  config_proto = utils.get_config_proto(
      log_device_placement=log_device_placement)
  # Create three sessions, one for each model
  train_sess = tf.Session(
      target=target_session, config=config_proto, graph=train_model.graph)
  eval_sess = tf.Session(
      target=target_session, config=config_proto, graph=eval_model.graph)
  infer_sess = tf.Session(
      target=target_session, config=config_proto, graph=infer_model.graph)

  # Load the train model from checkpoint or create a new one
  with train_model.graph.as_default():
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, model_dir, train_sess, name="train")

  # Summary writer
  summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, summary_name), train_model.graph)

  # First evaluation
  run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
                hparams, summary_writer, sample_src_data, sample_tgt_data)

  last_stats_step = global_step
  last_eval_step = global_step
  last_external_eval_step = global_step

  # This is the training loop.
  # Initialize the statistics for the loop.
  step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
  checkpoint_total_count = 0.0
  speed, train_ppl = 0.0, 0.0
  start_train_time = time.time()

  utils.print_out(
      "# Start step %d, lr %g, %s" %
      (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
       time.ctime()), log_f)

  # epoch_step records where we are within an epoch; it is used to skip
  # examples that have already been trained on.
  skip_count = hparams.batch_size * hparams.epoch_step
  utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
  # Initialize the training iterator
  train_sess.run(
      train_model.iterator.initializer,
      feed_dict={train_model.skip_count_placeholder: skip_count})

  # Train until we reach num_train_steps.
  while global_step < num_train_steps:
    # Run a step
    start_step_time = time.time()
    try:
      step_result = loaded_train_model.train(train_sess)
      (_,  # the output of the update op
       step_loss, step_predict_count, step_summary, global_step,
       step_word_count, batch_size) = step_result
      hparams.epoch_step += 1
    except tf.errors.OutOfRangeError:
      # Finished going through the training dataset. Go to the next epoch.
      hparams.epoch_step = 0
      utils.print_out(
          "# Finished an epoch, step %d. Perform external evaluation" %
          global_step)
      # Decode and print a random sentence
      run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                        summary_writer, sample_src_data, sample_tgt_data)
      # Perform external evaluation to save checkpoints if this is the best
      # for some metric.
      dev_scores, test_scores, _ = run_external_evaluation(
          infer_model, infer_sess, model_dir, hparams, summary_writer,
          save_on_best_dev=True)
      # Reinitialize the iterator from the beginning
      train_sess.run(
          train_model.iterator.initializer,
          feed_dict={train_model.skip_count_placeholder: 0})
      continue

    # Write step summary.
    summary_writer.add_summary(step_summary, global_step)

    # Update statistics
    step_time += (time.time() - start_step_time)
    checkpoint_loss += (step_loss * batch_size)
    checkpoint_predict_count += step_predict_count
    checkpoint_total_count += float(step_word_count)

    # Once in a while, we print statistics.
    if global_step - last_stats_step >= steps_per_stats:
      last_stats_step = global_step
      # Print statistics for the previous stats interval.
      avg_step_time = step_time / steps_per_stats
      train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
      speed = checkpoint_total_count / (1000 * step_time)
      utils.print_out(
          "  global step %d lr %g "
          "step-time %.2fs wps %.2fK ppl %.2f %s" %
          (global_step,
           loaded_train_model.learning_rate.eval(session=train_sess),
           avg_step_time, speed, train_ppl, _get_best_results(hparams)),
          log_f)
      if math.isnan(train_ppl):
        # The model has diverged
        break

      # Reset timer and loss.
      step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
      checkpoint_total_count = 0.0

    if global_step - last_eval_step >= steps_per_eval:
      # Perform evaluation. Start by reassigning last_eval_step to the
      # current step.
      last_eval_step = global_step
      # Print the progress and add a summary
      utils.print_out("# Save eval, global step %d" % global_step)
      utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)
      # Save checkpoint
      loaded_train_model.saver.save(
          train_sess,
          os.path.join(out_dir, "chatbot.ckpt"),
          global_step=global_step)
      # Decode and print a random sample
      run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                        summary_writer, sample_src_data, sample_tgt_data)
      # Run internal evaluation and update the ppl variables. The data
      # iterator is instantiated inside the method.
      dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess, model_dir,
                                            hparams, summary_writer)

    if global_step - last_external_eval_step >= steps_per_external_eval:
      # Run the external evaluation
      last_external_eval_step = global_step
      # Save checkpoint
      loaded_train_model.saver.save(
          train_sess,
          os.path.join(out_dir, "chatbot.ckpt"),
          global_step=global_step)
      # Decode and print a random sample
      run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                        summary_writer, sample_src_data, sample_tgt_data)
      # Run external evaluation, updating metric scores along the way. The
      # unneeded output is the global step.
      dev_scores, test_scores, _ = run_external_evaluation(
          infer_model, infer_sess, model_dir, hparams, summary_writer,
          save_on_best_dev=True)

  # Done training. Save the model.
  loaded_train_model.saver.save(
      train_sess,
      os.path.join(out_dir, "chatbot.ckpt"),
      global_step=global_step)

  result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
      model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
      summary_writer, sample_src_data, sample_tgt_data)
  utils.print_out(
      "# Final, step %d lr %g "
      "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
      (global_step,
       loaded_train_model.learning_rate.eval(session=train_sess),
       avg_step_time, speed, train_ppl, result_summary, time.ctime()), log_f)
  utils.print_time("# Done training!", start_train_time)

  utils.print_out("# Start evaluating saved best models.")
  for metric in hparams.metrics:
    best_model_dir = getattr(hparams, "best_" + metric + "_dir")
    result_summary, best_global_step, _, _, _, _ = run_full_eval(
        best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
        hparams, summary_writer, sample_src_data, sample_tgt_data)
    utils.print_out("# Best %s, step %d "
                    "step-time %.2f wps %.2fK, %s, %s" %
                    (metric, best_global_step, avg_step_time, speed,
                     result_summary, time.ctime()), log_f)

  summary_writer.close()
  return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
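# A small illustrative sketch (hypothetical helper, not in the codebase) of
# the statistics train() reports every steps_per_stats steps: perplexity is
# exp(accumulated loss / predicted token count), and speed is thousands of
# words per second. The min() cap is only a stand-in for the overflow guard
# in utils.safe_exp. Assumes the module's existing `math` import.
def checkpoint_stats(total_loss, predict_count, word_count, elapsed_secs):
  """Returns (train_ppl, words-per-second in thousands) for one stats window."""
  train_ppl = math.exp(min(total_loss / predict_count, 100.0))
  speed = word_count / (1000.0 * elapsed_secs)
  return train_ppl, speed

# For example, checkpoint_stats(total_loss=9000.0, predict_count=3000.0,
# word_count=3200000.0, elapsed_secs=50.0) yields ppl = exp(3.0), about 20.1,
# and speed = 64.0, matching the "wps %.2fK ppl %.2f" log line above.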