def run_external_evaluation(infer_model, infer_sess, model_dir, hparams,
                            summary_writer, save_on_best_dev):
    """Score the inference model on the dev split and, if set, the test split.

    The inference model is restored with `model_helper.create_or_load_model`
    from `model_dir`, then `_external_eval` is called once per split.

    Args:
        infer_model: Inference model bundle; its `graph`, `model`, `iterator`,
            `src_placeholder` and `batch_size_placeholder` are used.
        infer_sess: Session in which the inference model is restored and run.
        model_dir: Directory handed to `create_or_load_model`.
        hparams: Hyper-parameters; `dev_prefix`, `test_prefix`, `src`, `tgt`
            and `infer_batch_size` are read here.
        summary_writer: Forwarded to `_external_eval`.
        save_on_best_dev: Forwarded to `_external_eval` for the dev split.

    Returns:
        `(dev_scores, test_scores, global_step)`; `test_scores` is None when
        `hparams.test_prefix` is falsy.
    """
    with infer_model.graph.as_default():
        # Restore the inference model from the checkpoint directory.
        loaded_infer_model, global_step = model_helper.create_or_load_model(
            model=infer_model.model,
            model_dir=model_dir,
            session=infer_sess,
            name="infer")

    def score_split(prefix, label, save_best):
        # Build the iterator feed for one split and hand it to _external_eval.
        source_path = "%s.%s" % (prefix, hparams.src)
        reference_path = "%s.%s" % (prefix, hparams.tgt)
        source_lines = inference.load_data(source_path)
        feed = {
            infer_model.src_placeholder: source_lines,
            infer_model.batch_size_placeholder: hparams.infer_batch_size,
        }
        return _external_eval(
            model=loaded_infer_model,
            global_step=global_step,
            sess=infer_sess,
            hparams=hparams,
            iterator=infer_model.iterator,
            iterator_feed_dict=feed,
            tgt_file=reference_path,
            label=label,
            summary_writer=summary_writer,
            save_on_best_dev=save_best)

    dev_scores = score_split(hparams.dev_prefix, "dev", save_on_best_dev)

    test_scores = None
    if hparams.test_prefix:
        # The test split never drives checkpoint selection: using it during
        # training would amount to overfitting on the test set.
        test_scores = score_split(hparams.test_prefix, "test", False)

    return dev_scores, test_scores, global_step
def run_external_eval(infer_model, infer_sess, model_dir, hparams,
                      summary_writer, save_best_dev=True, use_test_set=True,
                      avg_ckpts=False, ckpt_index=None):
    """Compute external evaluation (bleu, rouge, etc.) for both dev / test.

    Args:
        infer_model: Inference model bundle (graph, model, iterator and the
            source / batch-size placeholders are used).
        infer_sess: Session used to restore and run the inference model.
        model_dir: Directory passed to `model_helper.create_or_load_model`.
        hparams: Hyper-parameters; reads `dev_prefix`, `test_prefix`, `src`,
            `tgt` and `infer_batch_size`.
        summary_writer: Forwarded to `_external_eval`.
        save_best_dev: Passed as `save_on_best` for the dev split.
        use_test_set: When false, the test split is skipped.
        avg_ckpts: Forwarded to `_external_eval` for both splits.
        ckpt_index: Forwarded as the last positional argument of
            `create_or_load_model`.

    Returns:
        `(dev_scores, test_scores, global_step)`; `test_scores` is None when
        the test split is skipped or `hparams.test_prefix` is falsy.
    """
    with infer_model.graph.as_default():
        restored = model_helper.create_or_load_model(
            infer_model.model, model_dir, infer_sess, "infer", ckpt_index)
    loaded_infer_model, global_step = restored

    def feed_for(source_path):
        # Iterator initialisation feed for one source file.
        return {
            infer_model.src_placeholder: inference.load_data(source_path),
            infer_model.batch_size_placeholder: hparams.infer_batch_size,
        }

    def evaluate(prefix, label, save_on_best):
        source_path = "%s.%s" % (prefix, hparams.src)
        reference_path = "%s.%s" % (prefix, hparams.tgt)
        return _external_eval(
            loaded_infer_model,
            global_step,
            infer_sess,
            hparams,
            infer_model.iterator,
            feed_for(source_path),
            reference_path,
            label,
            summary_writer,
            save_on_best=save_on_best,
            avg_ckpts=avg_ckpts)

    dev_scores = evaluate(hparams.dev_prefix, "dev", save_best_dev)

    if not (use_test_set and hparams.test_prefix):
        return dev_scores, None, global_step

    # The test split never triggers best-checkpoint saving.
    test_scores = evaluate(hparams.test_prefix, "test", False)
    return dev_scores, test_scores, global_step
def run_sample_decode_pungan_prepare(hparams, scope=None, target_session=""):
    """Decode the sample file and prefix each output with its source word.

    Builds an inference model, decodes the contents of `hparams.sample_prefix`
    with `run_sample_decode_pungan`, and prepends the first source word of
    each word pair to the decoded sentences of the matching block.

    Args:
        hparams: Hyper-parameters; reads `attention`, `encoder_type`,
            `attention_architecture`, `sample_prefix`, `sample_size`,
            `out_dir`, `log_device_placement` and the thread counts.
        scope: Forwarded to `model_helper.create_infer_model`.
        target_session: TensorFlow session target.

    Returns:
        `(wsd_src_data, eval_result_new)`: the loaded source data and the
        list of "<source word> <decoded sentence>" strings.

    Raises:
        ValueError: If `hparams.attention_architecture` is not recognised.
    """
    log_device_placement = hparams.log_device_placement

    if not hparams.attention:  # choose this model
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt" or
                hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    infer_model = model_helper.create_infer_model(model_creator, hparams, scope)

    wsd_src_file = "%s" % (hparams.sample_prefix)
    wsd_src_data = inference.load_data(wsd_src_file)
    model_dir = hparams.out_dir

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    infer_sess = tf.Session(
        target=target_session, config=config_proto, graph=infer_model.graph)

    eval_result = run_sample_decode_pungan(
        infer_model, infer_sess, model_dir, hparams, wsd_src_data)
    print('eval_result', eval_result)
    print('eval_result.len', len(eval_result))
    print('wsd_src_data', wsd_src_data)
    print('wsd_src_data.len', len(wsd_src_data))
    # Example from the original author: len=16, wsd_src_data =
    # [u'problem%1:10:00::', u'problem%1:26:00::', u'drive%1:04:00::',
    #  u'drive%1:04:03::', ..., u'file%2:32:00::', u'file%2:35:00::']

    sample_size = hparams.sample_size
    # Floor division keeps the block count an int on Python 3 as well;
    # on Python 2 the result is the same as the original `/`.
    num_blocks = len(eval_result) // (2 * sample_size)

    eval_result_new = []
    for block in range(num_blocks):
        # Source words come in pairs; only the first of each pair is prepended.
        src_word1, _ = wsd_src_data[2 * block], wsd_src_data[2 * block + 1]
        # NOTE(review): `.decode().encode('utf-8')` is a Python 2 unicode
        # idiom; Python 3 `str` has no `.decode()` - confirm the target runtime.
        prefix = src_word1.decode().encode('utf-8') + ' '
        for sent_id in range(block * sample_size, (block + 1) * sample_size):
            eval_result_new.append(prefix + eval_result[sent_id])

    return wsd_src_data, eval_result_new
def eval_fn(hparams, scope=None, target_session=""):
    """Evaluate a translation model on every checkpoint found in `out_dir`.

    Builds the eval and inference models, then calls `train.run_full_eval`
    once per entry in the checkpoint state's `all_model_checkpoint_paths`,
    passing the entry's index as `ckpt_index`.

    Args:
        hparams: Hyper-parameters; reads `attention`, `encoder_type`,
            `attention_architecture`, `dev_prefix`, `src`, `tgt`, `out_dir`,
            `avg_ckpts`, `log_device_placement` and the thread counts.
        scope: Forwarded to the model-creation helpers.
        target_session: TensorFlow session target.

    Raises:
        ValueError: If the attention architecture is unknown, or if no
            checkpoint state can be read from `hparams.out_dir`.
    """
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    avg_ckpts = hparams.avg_ckpts

    if not hparams.attention:
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt" or
                hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams, scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    model_dir = hparams.out_dir

    # Log and output files. The handle is only needed for this one line, so
    # it is closed right away instead of being left open for the whole run.
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    try:
        utils.print_out("# log_file=%s" % log_file, log_f)
    finally:
        log_f.close()

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)

    eval_sess = tf.Session(
        target=target_session, config=config_proto, graph=eval_model.graph)
    infer_sess = tf.Session(
        target=target_session, config=config_proto, graph=infer_model.graph)

    # get_checkpoint_state returns None when the directory holds no readable
    # checkpoint state; fail with a clear message rather than an
    # AttributeError on None.
    ckpt_state = tf.train.get_checkpoint_state(model_dir)
    if ckpt_state is None:
        raise ValueError("No checkpoint state found in %s" % model_dir)

    ckpt_size = len(ckpt_state.all_model_checkpoint_paths)
    for ckpt_index in range(ckpt_size):
        # No summary writer is used for this offline evaluation (None).
        train.run_full_eval(
            model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, None, sample_src_data, sample_tgt_data, avg_ckpts,
            ckpt_index=ckpt_index)
def train(hparams, scope=None, target_session="", compute_ppl=0):
    """Train a translation model for a fixed budget of 100 steps.

    Each call restores (or creates) the train model from `hparams.out_dir`,
    runs at most 100 further training steps, and saves a checkpoint at the
    end. `hparams.num_train_steps` is not consulted by this loop.

    Args:
        hparams: Hyper-parameters; `epoch_step` is updated in place.
        scope: Forwarded to the model-creation helpers.
        target_session: TensorFlow session target.
        compute_ppl: When truthy, `run_internal_eval` is called after every
            training step.

    Raises:
        ValueError: If `hparams.attention_architecture` is not recognised.
    """
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    avg_ckpts = hparams.avg_ckpts

    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if not hparams.attention:  # choose this model
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt" or
                hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    train_model = model_helper.create_train_model(model_creator, hparams, scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams, scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    # NOTE(review): the sample file is loaded but not used below; the load is
    # kept so a missing file still fails here - confirm whether it can go.
    wsd_src_file = "%s" % (hparams.sample_prefix)
    wsd_src_data = inference.load_data(wsd_src_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)

    train_sess = tf.Session(
        target=target_session, config=config_proto, graph=train_model.graph)
    eval_sess = tf.Session(
        target=target_session, config=config_proto, graph=eval_model.graph)
    infer_sess = tf.Session(
        target=target_session, config=config_proto, graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)

    checkpoint_path = os.path.join(out_dir, "translate.ckpt")

    # The writer and the log file are closed in `finally`, so buffered
    # summaries and log lines reach disk even if a step raises.
    try:
        # The initial full evaluation (run_full_eval) is disabled in this
        # variant.
        last_stats_step = global_step
        last_eval_step = global_step
        last_external_eval_step = global_step

        # This is the training loop.
        stats, info, start_train_time = before_train(
            loaded_train_model, train_model, train_sess, global_step, hparams,
            log_f)

        # Fixed per-call step budget, used instead of hparams.num_train_steps.
        end_step = global_step + 100
        while global_step < end_step:
            ### Run a step ###
            start_time = time.time()
            try:
                # Author's plan: forward the inference result to WSD to get a
                # reward, then feed it to the train model (loss * reward).
                step_result = loaded_train_model.train(train_sess)
                hparams.epoch_step += 1
            except tf.errors.OutOfRangeError:
                # Finished going through the training dataset. Go to next
                # epoch.
                hparams.epoch_step = 0
                utils.print_out(
                    "# Finished an epoch, step %d. Perform external evaluation"
                    % global_step)
                # Sample decoding and external evaluation at epoch end are
                # disabled here (original note: only for pretrain).
                if avg_ckpts:
                    run_avg_external_eval(infer_model, infer_sess, model_dir,
                                          hparams, summary_writer, global_step)
                train_sess.run(
                    train_model.iterator.initializer,
                    feed_dict={train_model.skip_count_placeholder: 0})
                continue

            # Process step_result, accumulate stats, and write summary
            global_step, info["learning_rate"], step_summary = update_stats(
                stats, start_time, step_result, hparams)
            summary_writer.add_summary(step_summary, global_step)

            if compute_ppl:
                run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                                  summary_writer)

            # Once in a while, we print statistics.
            if global_step - last_stats_step >= steps_per_stats:
                last_stats_step = global_step
                is_overflow = process_stats(stats, info, global_step,
                                            steps_per_stats, log_f)
                print_step_info(" ", global_step, info,
                                _get_best_results(hparams), log_f)
                if is_overflow:
                    break

                # Reset statistics
                stats = init_stats()

            if global_step - last_eval_step >= steps_per_eval:
                last_eval_step = global_step
                utils.print_out("# Save eval, global step %d" % global_step)
                utils.add_summary(summary_writer, global_step, "train_ppl",
                                  info["train_ppl"])

                # Save checkpoint
                loaded_train_model.saver.save(
                    train_sess, checkpoint_path, global_step=global_step)

                # Evaluate on dev/test
                run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                                  summary_writer, sample_src_data,
                                  sample_tgt_data)
                run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                                  summary_writer)

            if global_step - last_external_eval_step >= steps_per_external_eval:
                last_external_eval_step = global_step

                # Save checkpoint
                loaded_train_model.saver.save(
                    train_sess, checkpoint_path, global_step=global_step)
                run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                                  summary_writer, sample_src_data,
                                  sample_tgt_data)
                run_external_eval(infer_model, infer_sess, model_dir, hparams,
                                  summary_writer)
                if avg_ckpts:
                    run_avg_external_eval(infer_model, infer_sess, model_dir,
                                          hparams, summary_writer, global_step)

        # Done training
        loaded_train_model.saver.save(
            train_sess, checkpoint_path, global_step=global_step)
    finally:
        summary_writer.close()
        log_f.close()
def run_sample_decode_pungan_prepare(hparams, scope=None, target_session=""):
    """Decode the reversed sample file and build WSD input records.

    Steps:
      1. Write a copy of `hparams.sample_prefix` with the tokens of every
         line reversed (to `<sample_prefix>.new`) and load it.
      2. Decode the loaded lines in batches of 32 with
         `run_sample_decode_pungan`; lines beyond the last full batch are
         not decoded.
      3. Pair every decoded sentence with the sense keys read from
         `backward_step1.in` and build one WSD input dict per sentence.

    Args:
        hparams: Hyper-parameters; reads `attention`, `encoder_type`,
            `attention_architecture`, `sample_prefix`, `out_dir`,
            `log_device_placement` and the thread counts.
        scope: Forwarded to `model_helper.create_infer_model`.
        target_session: TensorFlow session target.

    Returns:
        `(wsd_input, senses_input, wsd_src_data, eval_result)`.

    Raises:
        ValueError: If `hparams.attention_architecture` is not recognised.
    """
    # Number of source lines decoded per call; it is also the number of
    # decoded sentences that share one pair of sense keys.
    decode_batch_size = 32

    log_device_placement = hparams.log_device_placement

    if not hparams.attention:  # choose this model
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt" or
                hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    infer_model = model_helper.create_infer_model(model_creator, hparams, scope)

    def dealt(src_path, dst_path):
        # Copy src_path to dst_path with the tokens of each line reversed.
        with open(src_path) as fin:
            with open(dst_path, 'w') as fout:
                for line in fin:
                    tokens = line.strip().split()
                    fout.write(' '.join(reversed(tokens)) + '\n')

    wsd_src_file = "%s" % (hparams.sample_prefix)
    wsd_src_file_new = wsd_src_file + '.new'
    dealt(wsd_src_file, wsd_src_file_new)
    wsd_src_data = inference.load_data(wsd_src_file_new)
    model_dir = hparams.out_dir

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    infer_sess = tf.Session(
        target=target_session, config=config_proto, graph=infer_model.graph)

    print('len wsd_src_data', len(wsd_src_data))
    eval_result = []
    # Floor division keeps the batch count an int on Python 3 as well; on
    # Python 2 the result is the same as the original `/`.
    for batch in range(len(wsd_src_data) // decode_batch_size):
        start = batch * decode_batch_size
        eval_result += run_sample_decode_pungan(
            infer_model, infer_sess, model_dir, hparams,
            wsd_src_data[start:start + decode_batch_size])
    print('eval_result')
    print(eval_result)
    print(len(eval_result))

    with open(PUNGAN_ROOT_PATH +
              '/Pun_Generation/data/1backward/backward_step1.in') as f:
        backward_step1_in = [line.strip() for line in f]

    def wsd_input_format(wsd_src_data, eval_result):
        """Build the WSD input dicts and the matching sense-key pairs.

        Example record from the original author (test_data[0]):
            {'target_word': u'art#n', 'target_sense': None,
             'id': 'senseval2.d000.s000.t000',
             'context': ['the', '<target>', 'of', 'change_ringing', ...],
             'poss': ['DET', 'NOUN', 'ADP', 'NOUN', ...]}
        """
        wsd_input = []
        senses_input = []
        for i, decoded in enumerate(eval_result):
            # Integer block index: one pair of sense keys per decoded batch.
            block = i // decode_batch_size
            src_word1 = backward_step1_in[2 * block]
            src_word2 = backward_step1_in[2 * block + 1]
            # NOTE(review): `.decode().encode('utf-8')` is a Python 2 unicode
            # idiom; Python 3 `str` has no `.decode()` - confirm the runtime.
            tgt_sent = wsd_src_data[i].decode().encode('utf-8') + ' ' + decoded
            tgt_word = src_word1
            synset_name = wn.lemma_from_key(tgt_word).synset().name()
            # First two dot-separated fields of the synset name, joined by '#'
            # (matches the 'art#n' shape in the example above).
            target_word = '#'.join(synset_name.split('.')[:2])
            context = ['<target>' if token == tgt_word else token
                       for token in tgt_sent.split(' ')]
            # Every position gets the placeholder tag '.'.
            poss_list = ['.'] * len(context)
            wsd_input.append({
                'target_word': target_word,
                'target_sense': None,
                'id': None,
                'context': context,
                'poss': poss_list,
            })
            senses_input.append((src_word1, src_word2))
        return wsd_input, senses_input

    wsd_input, senses_input = wsd_input_format(wsd_src_data, eval_result)
    print('wsd_input', wsd_input)
    print("len of wsd_input", len(wsd_input))
    return wsd_input, senses_input, wsd_src_data, eval_result
def train(hparams, scope=None, target_session=""):
    """Train a translation model.

    Builds separate train / eval / inference models and sessions, restores
    the train model from `hparams.out_dir` (or creates it), and runs the
    training loop until `global_step` reaches `hparams.num_train_steps` or
    `process_stats` reports an overflow. Statistics, checkpoints and
    evaluations are triggered on step-count thresholds. After the loop a
    final full evaluation is run, followed by a full evaluation of each
    saved best-metric model directory.

    Args:
        hparams: Hyper-parameters; `epoch_step` is updated in place.
        scope: Forwarded to the model-creation helpers.
        target_session: TensorFlow session target.

    Returns:
        `(final_eval_metrics, global_step)`, where `final_eval_metrics` is
        the third element returned by the final `run_full_eval` call.
    """
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    # Internal evaluation cadence is derived from the stats cadence.
    steps_per_eval = 10 * steps_per_stats
    avg_ckpts = hparams.avg_ckpts

    # Fall back to a derived cadence when no external-eval interval is set.
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    # Create model
    model_creator = get_model_creator(hparams)
    train_model = model_helper.create_train_model(model_creator, hparams, scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams, scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    # One session per graph: train, eval and inference each have their own.
    train_sess = tf.Session(
        target=target_session, config=config_proto, graph=train_model.graph)
    eval_sess = tf.Session(
        target=target_session, config=config_proto, graph=eval_model.graph)
    infer_sess = tf.Session(
        target=target_session, config=config_proto, graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)

    # First evaluation
    run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data, avg_ckpts)

    # Step counters marking when each periodic action last fired.
    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats, info, start_train_time = before_train(
        loaded_train_model, train_model, train_sess, global_step, hparams, log_f)
    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset. Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir, hparams,
                                      summary_writer, global_step)

            # Restart the training iterator from the beginning (skip count 0).
            train_sess.run(
                train_model.iterator.initializer,
                feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Process step_result, accumulate stats, and write summary
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = process_stats(
                stats, info, global_step, steps_per_stats, log_f)
            print_step_info(" ", global_step, info, get_best_results(hparams),
                            log_f)
            # Stop training as soon as process_stats reports an overflow.
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            add_info_summaries(summary_writer, global_step, info)

            # Save checkpoint
            loaded_train_model.saver.save(
                train_sess,
                os.path.join(out_dir, "translate.ckpt"),
                global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess,
                              model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            run_internal_eval(
                eval_model, eval_sess, model_dir, hparams, summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(
                train_sess,
                os.path.join(out_dir, "translate.ckpt"),
                global_step=global_step)
            run_sample_decode(infer_model, infer_sess,
                              model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            run_external_eval(
                infer_model, infer_sess, model_dir,
                hparams, summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir, hparams,
                                      summary_writer, global_step)

    # Done training
    loaded_train_model.saver.save(
        train_sess,
        os.path.join(out_dir, "translate.ckpt"),
        global_step=global_step)

    # Final full evaluation of the model that was just saved.
    (result_summary, _, final_eval_metrics) = (
        run_full_eval(
            model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
            summary_writer, sample_src_data, sample_tgt_data, avg_ckpts))
    print_step_info("# Final, ", global_step, info, result_summary, log_f)
    utils.print_time("# Done training!", start_train_time)

    summary_writer.close()

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        # Each metric has its own directory, read from hparams.best_<metric>_dir,
        # and gets a fresh summary writer inside that directory.
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
            summary_writer, sample_src_data, sample_tgt_data)
        print_step_info("# Best %s, " % metric, best_global_step, info,
                        result_summary, log_f)
        summary_writer.close()

        if avg_ckpts:
            # Same evaluation for the directory in hparams.avg_best_<metric>_dir.
            best_model_dir = getattr(hparams, "avg_best_" + metric + "_dir")
            summary_writer = tf.summary.FileWriter(
                os.path.join(best_model_dir, summary_name), infer_model.graph)
            result_summary, best_global_step, _ = run_full_eval(
                best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
                hparams, summary_writer, sample_src_data, sample_tgt_data)
            print_step_info("# Averaged Best %s, " % metric, best_global_step, info,
                            result_summary, log_f)
            summary_writer.close()

    return final_eval_metrics, global_step
def run_external_eval(infer_model, infer_sess, model_dir, hparams,
                      summary_writer, save_best_dev=True, use_test_set=True,
                      avg_ckpts=False, dev_infer_iterator_feed_dict=None,
                      test_infer_iterator_feed_dict=None):
    """Compute external evaluation for both dev / test.

    Computes development and testing external evaluation (e.g. bleu, rouge)
    for given model.

    Args:
        infer_model: Inference model for which to compute the evaluation.
        infer_sess: Inference TensorFlow session.
        model_dir: Directory from which to load inference model from.
        hparams: Model hyper-parameters.
        summary_writer: Summary writer, forwarded to `_external_eval`.
        save_best_dev: Passed as `save_on_best` for the development split;
            the testing split always passes False.
        use_test_set: Computes testing external evaluation if true; does not
            otherwise. Note that the development external evaluation is
            always computed regardless of value of this parameter.
        avg_ckpts: Forwarded to `_external_eval` for both splits.
        dev_infer_iterator_feed_dict: Feed dictionary for a TensorFlow
            session. Can be used to pass in additional inputs necessary for
            running the development external evaluation. It is copied, not
            modified.
        test_infer_iterator_feed_dict: Feed dictionary for a TensorFlow
            session. Can be used to pass in additional inputs necessary for
            running the testing external evaluation. It is copied, not
            modified.

    Returns:
        Triple containing development scores, testing scores and the global
        step returned by `model_helper.create_or_load_model`, in this order.
        Testing scores are None when the test split is not evaluated.
    """
    # Work on copies so the caller's dictionaries are not mutated by the
    # placeholder entries added below.
    dev_feed_dict = dict(dev_infer_iterator_feed_dict or {})
    test_feed_dict = dict(test_infer_iterator_feed_dict or {})

    with infer_model.graph.as_default():
        loaded_infer_model, global_step = model_helper.create_or_load_model(
            infer_model.model, model_dir, infer_sess, "infer")

    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    dev_feed_dict[infer_model.src_placeholder] = inference.load_data(
        dev_src_file)
    dev_feed_dict[infer_model.batch_size_placeholder] = hparams.infer_batch_size
    dev_scores = _external_eval(
        loaded_infer_model,
        global_step,
        infer_sess,
        hparams,
        infer_model.iterator,
        dev_feed_dict,
        dev_tgt_file,
        "dev",
        summary_writer,
        save_on_best=save_best_dev,
        avg_ckpts=avg_ckpts)

    test_scores = None
    if use_test_set and hparams.test_prefix:
        test_src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
        test_tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
        test_feed_dict[infer_model.src_placeholder] = inference.load_data(
            test_src_file)
        test_feed_dict[infer_model.batch_size_placeholder] = (
            hparams.infer_batch_size)
        test_scores = _external_eval(
            loaded_infer_model,
            global_step,
            infer_sess,
            hparams,
            infer_model.iterator,
            test_feed_dict,
            test_tgt_file,
            "test",
            summary_writer,
            save_on_best=False,
            avg_ckpts=avg_ckpts)

    return dev_scores, test_scores, global_step
def train(hparams, scope=None, target_session=''):
    """Train the chatbot.

    Selects the model class and iterator factories from
    `hparams.architecture`, builds train / eval / inference models and
    sessions, and runs the training loop until `global_step` reaches
    `hparams.num_train_steps` or the training perplexity becomes NaN.
    Afterwards a final full evaluation is run, followed by a full evaluation
    of each saved best-metric model directory.

    Args:
        hparams: Hyper-parameters; `epoch_step` is updated in place.
        scope: Forwarded to the model-creation helpers.
        target_session: TensorFlow session target.

    Returns:
        `(dev_scores, test_scores, dev_ppl, test_ppl, global_step)`, with the
        scores and perplexities taken from the final `run_full_eval` call.

    Raises:
        ValueError: If `hparams.architecture` is neither "simple" nor "hier".
    """
    # Initialize some local hyperparameters
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    # Internal evaluation cadence is derived from the stats cadence.
    steps_per_eval = 10 * steps_per_stats
    # Fall back to a derived cadence when no external-eval interval is set.
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if hparams.architecture == "simple":
        model_creator = SimpleModel
        get_infer_iterator = iterator_utils.get_infer_iterator
        get_iterator = iterator_utils.get_iterator
    elif hparams.architecture == "hier":
        model_creator = HierarchicalModel

        # Parse some of the arguments now
        # The wrappers below bind the extra arguments taken from hparams
        # (eou, dialogue_max_len) before delegating to end2end_iterator_utils.
        def curry_get_infer_iterator(dataset, vocab_table, batch_size, src_reverse,
                                     eos, src_max_len):
            return end2end_iterator_utils.get_infer_iterator(
                dataset, vocab_table, batch_size, src_reverse, eos,
                src_max_len=src_max_len, eou=hparams.eou,
                dialogue_max_len=hparams.dialogue_max_len)
        get_infer_iterator = curry_get_infer_iterator

        def curry_get_iterator(src_dataset, tgt_dataset, vocab_table, batch_size,
                               sos, eos, src_reverse, random_seed, num_buckets,
                               src_max_len=None, tgt_max_len=None, num_threads=4,
                               output_buffer_size=None, skip_count=None):
            # num_buckets is passed on under the name num_dialogue_buckets.
            return end2end_iterator_utils.get_iterator(
                src_dataset, tgt_dataset, vocab_table, batch_size, sos, eos,
                eou=hparams.eou, src_reverse=src_reverse, random_seed=random_seed,
                num_dialogue_buckets=num_buckets, src_max_len=src_max_len,
                tgt_max_len=tgt_max_len, num_threads=num_threads,
                output_buffer_size=output_buffer_size, skip_count=skip_count)
        get_iterator = curry_get_iterator
    else:
        raise ValueError("Unkown architecture", hparams.architecture)

    # Create three models which share parameters through the use of checkpoints
    train_model = create_train_model(model_creator, get_iterator, hparams, scope)
    eval_model = create_eval_model(model_creator, get_iterator, hparams, scope)
    infer_model = inference.create_infer_model(model_creator, get_infer_iterator,
                                               hparams, scope)
    # TODO: adapt for architectures

    # Preload the data to use for sample decoding
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    # Initialised here because it is printed after the loop even if the
    # statistics branch never ran.
    avg_step_time = 0.0

    # Create the configurations for the sessions
    config_proto = utils.get_config_proto(log_device_placement=log_device_placement)
    # Create three sessions, one for each model
    train_sess = tf.Session(
        target=target_session, config=config_proto, graph=train_model.graph)
    eval_sess = tf.Session(
        target=target_session, config=config_proto, graph=eval_model.graph)
    infer_sess = tf.Session(
        target=target_session, config=config_proto, graph=infer_model.graph)

    # Load the train model from checkpoint or create a new one
    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, name="train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)

    # First evaluation
    run_full_eval(
        model_dir, infer_model, infer_sess,
        eval_model, eval_sess, hparams,
        summary_writer, sample_src_data,
        sample_tgt_data)

    # Step counters marking when each periodic action last fired.
    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    # Initialize the running statistics for the loop.
    step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
    checkpoint_total_count = 0.0
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
         time.ctime()),
        log_f)

    # epoch_step records where we were within an epoch. Used to skip
    # examples that were already trained on.
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    # Initialize the training iterator
    train_sess.run(
        train_model.iterator.initializer,
        feed_dict={train_model.skip_count_placeholder: skip_count})

    # Train until we reach num_train_steps.
    while global_step < num_train_steps:
        # Run a step
        start_step_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            # The leading _ is the output of the update op.
            (_, step_loss, step_predict_count, step_summary, global_step,
             step_word_count, batch_size) = step_result
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset. Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            # Decode and print a random sentence
            run_sample_decode(infer_model, infer_sess,
                              model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            # Perform external evaluation to save checkpoints if this is the
            # best for some metric
            dev_scores, test_scores, _ = run_external_evaluation(
                infer_model, infer_sess, model_dir, hparams, summary_writer,
                save_on_best_dev=True)
            # Reinitialize the iterator from the beginning
            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary.
        summary_writer.add_summary(step_summary, global_step)

        # Update statistics
        step_time += (time.time() - start_step_time)
        # The step loss is weighted by batch size before accumulation.
        checkpoint_loss += (step_loss * batch_size)
        checkpoint_predict_count += step_predict_count
        checkpoint_total_count += float(step_word_count)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step

            # Print statistics accumulated since the last report.
            avg_step_time = step_time / steps_per_stats
            train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
            # Word throughput in thousands per second.
            speed = checkpoint_total_count / (1000 * step_time)
            utils.print_out(
                " global step %d lr %g "
                "step-time %.2fs wps %.2fK ppl %.2f %s" %
                (global_step,
                 loaded_train_model.learning_rate.eval(session=train_sess),
                 avg_step_time, speed, train_ppl, _get_best_results(hparams)),
                log_f)
            if math.isnan(train_ppl):
                # The perplexity is NaN: training has diverged, stop.
                break

            # Reset timer and loss.
            step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
            checkpoint_total_count = 0.0

        if global_step - last_eval_step >= steps_per_eval:
            # Perform evaluation. Start by recording the current step as the
            # last evaluation step.
            last_eval_step = global_step
            # Print the progress and add summary
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir, "chatbot.ckpt"),
                                          global_step=global_step)
            # Decode and print a random sample
            run_sample_decode(infer_model, infer_sess,
                              model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            # Run internal evaluation and update the perplexity variables.
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess, model_dir,
                                                  hparams, summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            # Run the external evaluation
            last_external_eval_step = global_step
            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir, "chatbot.ckpt"),
                                          global_step=global_step)
            # Decode and print a random sample
            run_sample_decode(infer_model, infer_sess,
                              model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            # Run external evaluation, updating the metric scores. The
            # discarded output is the global step.
            dev_scores, test_scores, _ = run_external_evaluation(
                infer_model, infer_sess, model_dir, hparams, summary_writer,
                save_on_best_dev=True)

    # Done training. Save the model
    loaded_train_model.saver.save(
        train_sess,
        os.path.join(out_dir, "chatbot.ckpt"),
        global_step=global_step)

    # Final full evaluation; its scores and perplexities are what is returned.
    result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
        model_dir, infer_model, infer_sess,
        eval_model, eval_sess, hparams,
        summary_writer, sample_src_data,
        sample_tgt_data)

    utils.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()),
        log_f)
    utils.print_time("# Done training!", start_train_time)

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        # Each metric has its own directory, read from hparams.best_<metric>_dir.
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        result_summary, best_global_step, _, _, _, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
            summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out("# Best %s, step %d "
                        "step-time %.2f wps %.2fK, %s, %s" %
                        (metric, best_global_step, avg_step_time, speed,
                         result_summary, time.ctime()), log_f)

    summary_writer.close()
    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
def train(hparams, scope=None, target_session="", single_cell_fn=None):
    """Train a translation model.

    Two modes are selected by ``hparams.eval_on_fly``:

    * truthy: eval and inference graphs are built next to the training graph,
      and internal / external evaluation runs periodically during training
      and once more at the end (including the saved best models).
    * falsy: only the training graph is used and a checkpoint is written
      every ``hparams.snapshot_interval`` steps.

    Args:
        hparams: Hyper-parameter object; ``hparams.epoch_step`` is updated
            in place while training.
        scope: Forwarded to the model creation helpers.
        target_session: ``target`` argument for every ``tf.Session`` created.
        single_cell_fn: Forwarded to the model creation helpers.

    Returns:
        ``(dev_scores, test_scores, dev_ppl, test_ppl, global_step)`` when
        ``hparams.eval_on_fly`` is truthy, otherwise just ``global_step``.

    Raises:
        ValueError: If the attention architecture is not recognised.
    """
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats

    if hparams.eval_on_fly:
        steps_per_external_eval = hparams.steps_per_external_eval
        steps_per_eval = 10 * steps_per_stats
        if not steps_per_external_eval:
            steps_per_external_eval = 2 * steps_per_eval
    else:
        steps_per_snapshot = hparams.snapshot_interval

    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
    else:
        raise ValueError("Unknown model architecture")

    train_model = create_train_model(model_creator, hparams, scope,
                                     single_cell_fn)
    # The eval / inference graphs are only needed when evaluating on the fly.
    if hparams.eval_on_fly:
        eval_model = create_eval_model(model_creator, hparams, scope,
                                       single_cell_fn)
        infer_model = inference.create_infer_model(model_creator, hparams,
                                                   scope, single_cell_fn)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir
    # Every checkpoint written by this function uses the same path prefix.
    ckpt_path = os.path.join(out_dir, "translate.ckpt")

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement)

    train_sess = tf.Session(target=target_session, config=config_proto,
                            graph=train_model.graph)
    if hparams.eval_on_fly:
        eval_sess = tf.Session(target=target_session, config=config_proto,
                               graph=eval_model.graph)
        infer_sess = tf.Session(target=target_session, config=config_proto,
                                graph=infer_model.graph)

    with train_model.graph.as_default():
        model_helper.initialize_cnn(train_model.model, train_sess)
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)

    # First evaluation
    if hparams.eval_on_fly:
        run_full_eval(model_dir, infer_model, infer_sess, eval_model,
                      eval_sess, hparams, summary_writer, sample_src_data,
                      sample_tgt_data)

    last_stats_step = global_step
    if hparams.eval_on_fly:
        last_eval_step = global_step
        last_external_eval_step = global_step
    else:
        last_snapshot_step = global_step

    # This is the training loop.
    step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
    checkpoint_total_count = 0.0
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step,
         loaded_train_model.learning_rate.eval(session=train_sess),
         time.ctime()),
        log_f)

    # Initialize all of the iterators
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    train_sess.run(train_model.iterator.initializer,
                   feed_dict={train_model.skip_count_placeholder: skip_count})

    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            (_, step_loss, step_predict_count, step_summary, global_step,
             step_word_count, batch_size) = step_result
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset. Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            if hparams.eval_on_fly:
                run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                                  summary_writer, sample_src_data,
                                  sample_tgt_data)
                dev_scores, test_scores, _ = run_external_eval(
                    infer_model, infer_sess, model_dir, hparams,
                    summary_writer)
            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary.
        summary_writer.add_summary(step_summary, global_step)

        # update statistics
        step_time += (time.time() - start_time)
        checkpoint_loss += (step_loss * batch_size)
        checkpoint_predict_count += step_predict_count
        checkpoint_total_count += float(step_word_count)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step

            # Print statistics for the previous epoch.
            avg_step_time = step_time / steps_per_stats
            train_ppl = utils.safe_exp(
                checkpoint_loss / checkpoint_predict_count)
            speed = checkpoint_total_count / (1000 * step_time)
            utils.print_out(
                " global step %d lr %g step-time %.2fs wps %.2fK ppl %.2f %s" %
                (global_step,
                 loaded_train_model.learning_rate.eval(session=train_sess),
                 avg_step_time, speed, train_ppl, _get_best_results(hparams)),
                log_f)
            # A NaN perplexity means training has diverged; stop stepping.
            if math.isnan(train_ppl):
                break

            # Reset timer and loss.
            step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
            checkpoint_total_count = 0.0

        # Snapshot-only mode: periodically write a checkpoint, no evaluation.
        if (not hparams.eval_on_fly) and (
                global_step - last_snapshot_step >= steps_per_snapshot):
            last_snapshot_step = global_step
            utils.print_out(
                "# Cihan: Saving Snapshot, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)
            # Save checkpoint
            loaded_train_model.saver.save(train_sess, ckpt_path,
                                          global_step=global_step)

        if hparams.eval_on_fly and (
                global_step - last_eval_step >= steps_per_eval):
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)
            # Save checkpoint
            loaded_train_model.saver.save(train_sess, ckpt_path,
                                          global_step=global_step)
            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data,
                              sample_tgt_data)
            dev_ppl, test_ppl = run_internal_eval(
                eval_model, eval_sess, model_dir, hparams, summary_writer)

        if hparams.eval_on_fly and (
                global_step - last_external_eval_step >=
                steps_per_external_eval):
            last_external_eval_step = global_step
            # Save checkpoint
            loaded_train_model.saver.save(train_sess, ckpt_path,
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data,
                              sample_tgt_data)
            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hparams, summary_writer)

    # Done training
    loaded_train_model.saver.save(train_sess, ckpt_path,
                                  global_step=global_step)

    # Everything below needs the eval / inference graphs and the result
    # summary, so it only runs when evaluating on the fly.
    if hparams.eval_on_fly:
        (result_summary, _, dev_scores, test_scores, dev_ppl,
         test_ppl) = run_full_eval(
             model_dir, infer_model, infer_sess, eval_model, eval_sess,
             hparams, summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out(
            "# Final, step %d lr %g step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
            (global_step,
             loaded_train_model.learning_rate.eval(session=train_sess),
             avg_step_time, speed, train_ppl, result_summary, time.ctime()),
            log_f)
        utils.print_time("# Done training!", start_train_time)

        utils.print_out("# Start evaluating saved best models.")
        for metric in hparams.metrics:
            best_model_dir = getattr(hparams, "best_" + metric + "_dir")
            result_summary, best_global_step, _, _, _, _ = run_full_eval(
                best_model_dir, infer_model, infer_sess, eval_model,
                eval_sess, hparams, summary_writer, sample_src_data,
                sample_tgt_data)
            utils.print_out(
                "# Best %s, step %d step-time %.2f wps %.2fK, %s, %s" %
                (metric, best_global_step, avg_step_time, speed,
                 result_summary, time.ctime()),
                log_f)

    summary_writer.close()
    # Release the log file handle explicitly instead of relying on GC.
    log_f.close()

    if hparams.eval_on_fly:
        return dev_scores, test_scores, dev_ppl, test_ppl, global_step
    else:
        return global_step
def train(hparams, scope=None, target_session=""):
    """Train a translation model, optionally with a curriculum.

    ``hparams.curriculum`` selects how training batches are drawn:

    * ``'none'``: a single training iterator is used.
    * ``'predictive_gain'``: an ``Exp3S`` instance picks which bucket
      iterator to draw from and is updated with the loss difference measured
      by re-evaluating the same batch after the training step.
    * ``'look_back_and_forward'``: a moving ``curriculum_point`` selects the
      bucket most of the time; it advances when the step loss drops below a
      threshold derived from ``hparams.curriculum_progress_loss``.

    Internal and external evaluation both run every ``10 * steps_per_stats``
    steps, and once more after training.

    Args:
        hparams: Hyper-parameter object; ``hparams.epoch_step`` is updated
            in place while training.
        scope: Forwarded to the model creation helpers.
        target_session: ``target`` argument for every ``tf.Session`` created.

    Returns:
        ``(dev_scores, test_scores, dev_ppl, test_ppl, global_step)``.

    Raises:
        ValueError: If the attention architecture is not recognised.
    """
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_eval = 10 * steps_per_stats

    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
    else:
        raise ValueError("Unknown model architecture")

    train_model = model_helper.create_train_model(model_creator, hparams,
                                                  scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir
    # Every checkpoint written by this function uses the same path prefix.
    ckpt_path = os.path.join(out_dir, "translate.ckpt")

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="w")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement)

    train_sess = tf.Session(target=target_session, config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session, config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session, config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)

    # First evaluation
    run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
                  hparams, summary_writer, sample_src_data, sample_tgt_data)

    last_stats_step = global_step
    last_eval_step = global_step

    # This is the training loop.
    step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
    checkpoint_total_count = 0.0
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step,
         loaded_train_model.learning_rate.eval(session=train_sess),
         time.ctime()),
        log_f)

    # Initialize all of the iterators
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)

    if hparams.curriculum == 'none':
        train_sess.run(
            train_model.iterator.initializer,
            feed_dict={train_model.skip_count_placeholder: skip_count})
    else:
        if hparams.curriculum == 'predictive_gain':
            exp3s = Exp3S(hparams.num_curriculum_buckets, 0.001, 0, 0.05)
        elif hparams.curriculum == 'look_back_and_forward':
            curriculum_point = 0

        # One iterator per curriculum bucket; the string handles are fed to
        # the training step to choose which bucket a batch comes from.
        handle = train_model.iterator.handle
        for i in range(hparams.num_curriculum_buckets):
            train_sess.run(
                train_model.iterator.initializer[i].initializer,
                feed_dict={train_model.skip_count_placeholder: skip_count})
        iterator_handles = [
            train_sess.run(
                train_model.iterator.initializer[i].string_handle(),
                feed_dict={train_model.skip_count_placeholder: skip_count})
            for i in range(hparams.num_curriculum_buckets)
        ]

    utils.print_out("Starting training")
    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            if hparams.curriculum != 'none':
                if hparams.curriculum == 'predictive_gain':
                    lesson = exp3s.draw_task()
                elif hparams.curriculum == 'look_back_and_forward':
                    # ``>=`` (not ``==``): once the curriculum is exhausted
                    # the point must never be used as a bucket index, or
                    # ``iterator_handles[lesson]`` would go out of range.
                    if curriculum_point >= hparams.num_curriculum_buckets:
                        lesson = np.random.randint(
                            low=0, high=hparams.num_curriculum_buckets)
                    else:
                        lesson = (
                            curriculum_point
                            if np.random.random_sample() < 0.8
                            else np.random.randint(
                                low=0, high=hparams.num_curriculum_buckets))

                step_result = loaded_train_model.train(
                    hparams,
                    train_sess,
                    handle=handle,
                    iterator_handle=iterator_handles[lesson],
                    use_fed_source_placeholder=(
                        loaded_train_model.use_fed_source),
                    fed_source_placeholder=loaded_train_model.fed_source)
                (_, step_loss, step_predict_count, step_summary, global_step,
                 step_word_count, batch_size, source) = step_result

                if hparams.curriculum == 'predictive_gain':
                    # Loss on the same source batch after the update.
                    # NOTE(review): run() is given a list, so new_loss is a
                    # one-element list rather than a scalar - confirm that
                    # Exp3S.update_w expects that.
                    new_loss = train_sess.run(
                        [loaded_train_model.train_loss],
                        feed_dict={
                            handle: iterator_handles[lesson],
                            loaded_train_model.use_fed_source: True,
                            loaded_train_model.fed_source: source
                        })
                    curriculum_point_a = lesson * (
                        hparams.src_max_len //
                        hparams.num_curriculum_buckets) + 1
                    curriculum_point_b = (lesson + 1) * (
                        hparams.src_max_len //
                        hparams.num_curriculum_buckets) + 1
                    v = step_loss - new_loss
                    exp3s.update_w(
                        v,
                        float(curriculum_point_a + curriculum_point_b) / 2.0)
                elif hparams.curriculum == 'look_back_and_forward':
                    utils.print_out("step loss: %s, lesson: %s" %
                                    (step_loss, lesson))
                    curriculum_point_a = curriculum_point * (
                        hparams.src_max_len //
                        hparams.num_curriculum_buckets) + 1
                    curriculum_point_b = (curriculum_point + 1) * (
                        hparams.src_max_len //
                        hparams.num_curriculum_buckets) + 1
                    # Advance at most up to num_curriculum_buckets, the
                    # "curriculum finished" state handled above.
                    if (curriculum_point < hparams.num_curriculum_buckets
                            and step_loss <
                            (hparams.curriculum_progress_loss *
                             (float(curriculum_point_a + curriculum_point_b)
                              / 2.0))):
                        curriculum_point += 1
            else:
                step_result = loaded_train_model.train(hparams, train_sess)
                (_, step_loss, step_predict_count, step_summary, global_step,
                 step_word_count, batch_size) = step_result
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset. Go to next epoch.
            hparams.epoch_step = 0
            if hparams.curriculum == 'none':
                train_sess.run(
                    train_model.iterator.initializer,
                    feed_dict={train_model.skip_count_placeholder: 0})
            else:
                # Only the bucket that ran out is re-initialised.
                train_sess.run(
                    train_model.iterator.initializer[lesson].initializer,
                    feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary.
        summary_writer.add_summary(step_summary, global_step)

        # update statistics
        step_time += (time.time() - start_time)
        checkpoint_loss += (step_loss * batch_size)
        checkpoint_predict_count += step_predict_count
        checkpoint_total_count += float(step_word_count)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            if hparams.curriculum == 'predictive_gain':
                utils.print_out("lesson: %s, step loss: %s, new_loss: %s" %
                                (lesson, step_loss, new_loss))
                utils.print_out("exp3s dist: %s" % (exp3s.pi, ))
            last_stats_step = global_step

            # Print statistics for the previous epoch.
            avg_step_time = step_time / steps_per_stats
            train_ppl = utils.safe_exp(
                checkpoint_loss / checkpoint_predict_count)
            speed = checkpoint_total_count / (1000 * step_time)
            utils.print_out(
                " global step %d lr %g "
                "step-time %.2fs wps %.2fK ppl %.2f %s" %
                (global_step,
                 loaded_train_model.learning_rate.eval(session=train_sess),
                 avg_step_time, speed, train_ppl, _get_best_results(hparams)),
                log_f)
            # A NaN perplexity means training has diverged; stop stepping.
            if math.isnan(train_ppl):
                break

            # Reset timer and loss.
            step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
            checkpoint_total_count = 0.0

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)
            # Save checkpoint
            loaded_train_model.saver.save(train_sess, ckpt_path,
                                          global_step=global_step)
            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data,
                              sample_tgt_data)
            dev_ppl, test_ppl = run_internal_eval(
                eval_model, eval_sess, model_dir, hparams, summary_writer)
            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hparams, summary_writer)

    # Done training
    loaded_train_model.saver.save(train_sess, ckpt_path,
                                  global_step=global_step)

    (result_summary, _, dev_scores, test_scores, dev_ppl,
     test_ppl) = run_full_eval(
         model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
         summary_writer, sample_src_data, sample_tgt_data)
    utils.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step,
         loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()),
        log_f)
    utils.print_time("# Done training!", start_train_time)

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        result_summary, best_global_step, _, _, _, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out(
            "# Best %s, step %d "
            "step-time %.2f wps %.2fK, %s, %s" %
            (metric, best_global_step, avg_step_time, speed, result_summary,
             time.ctime()),
            log_f)

    summary_writer.close()
    # Release the log file handle explicitly instead of relying on GC.
    log_f.close()
    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
# Results live in the "rm_static_nearby" folder under root_path; create it on
# first use.
result_path = osp.join(root_path, "rm_static_nearby")
if not osp.exists(result_path):
    os.makedirs(result_path)

if __name__ == "__main__":
    # Saved predictions and ground-truth labels for the train / val splits.
    pred_labels_train = np.load(
        osp.join(result_path, "pred_labels_train_pnt2.npy"))
    pred_labels_val = np.load(
        osp.join(result_path, "pred_labels_val_pnt2.npy"))
    true_labels_train = np.load(
        osp.join(result_path, "true_labels_train_pnt2.npy"))
    true_labels_val = np.load(
        osp.join(result_path, "true_labels_val_pnt2.npy"))

    # Keep only the first three entries along the last axis of the features.
    features, labels = load_data(data_path)
    features = features[:, :, :3]

    # Output folder for rendered images, created on demand.
    img_path = osp.join(result_path, "pointnet_result")
    color_map = ['y', 'g', 'r', 'b']
    if not osp.exists(img_path):
        print("Creating image folder...")
        os.makedirs(img_path)

    # Stack train and val along axis 0 and report the resulting shapes.
    pred_labels = np.concatenate(
        (pred_labels_train, pred_labels_val), axis=0)
    print(pred_labels.shape)
    true_labels = np.concatenate(
        (true_labels_train, true_labels_val), axis=0)
    print(true_labels.shape)

    # One progress tick per entry along the first axis of features.
    pbar = tqdm.tqdm(total=features.shape[0])
    for i in range(features.shape[0]):
        pbar.update()
def train(config):
    """Run the training loop, printing per-step debug output.

    Trains until the global step reaches ``config.num_train_steps``. Whenever
    the global step is a multiple of ``config.steps_per_eval`` a checkpoint is
    saved to ``config.out_dir``, a sample is decoded and the eval loss is
    computed and printed.

    Args:
        config: Configuration object; the attributes read here are
            ``out_dir``, ``attention``, ``num_train_steps`` and
            ``steps_per_eval`` (it is also forwarded to the graph builders
            and evaluation helpers).
    """
    out_dir = config.out_dir
    if not config.attention:
        model_creator = models.VanillaModel
    else:
        model_creator = models.AttentionModel

    train_model = model_base.build_model_graph(model_creator, config, "train")
    eval_model = model_base.build_model_graph(model_creator, config, "eval")
    test_model = inference.build_inference_graph(model_creator, config)

    test_src = inference.load_data(test_model.src_file)
    test_tgt = inference.load_data(test_model.tgt_file)

    train_sess = tf.Session(graph=train_model.graph)
    eval_sess = tf.Session(graph=eval_model.graph)
    test_sess = tf.Session(graph=test_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_base.create_or_load_model(
            train_model.model, out_dir, train_sess, "train")

    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, "train_log"), train_model.graph)

    train_sess.run(train_model.iterator.initializer)

    while global_step < config.num_train_steps:
        try:
            step_result = loaded_train_model.train(train_sess, debug=True)
            (_, loss, predict_count, global_step, summary, src, src_len, tgt,
             tgt_len, preds) = step_result
            # Debug dump of the step; single-argument print(...) produces the
            # same output on Python 2 and Python 3.
            print(loss)
            print(src)
            print(src_len)
            print(tgt)
            print(tgt_len)
            print(preds)
            print("")
        except tf.errors.OutOfRangeError:
            # epoch is done
            train_sess.run(train_model.iterator.initializer)
            # No step ran in this iteration: skip the summary / eval below
            # rather than reusing the previous step's outputs (which do not
            # exist at all if this happens on the first iteration).
            continue

        summary_writer.add_summary(summary, global_step)
        print("%s %s" % (global_step, config.steps_per_eval))

        if global_step % config.steps_per_eval == 0:
            # save and evaluate
            loaded_train_model.saver.save(
                train_sess,
                os.path.join(out_dir, "model.ckpt"),
                global_step=global_step)
            run_sample_decode(
                test_model, test_sess, out_dir, config, summary_writer,
                test_src, test_tgt)
            eval_loss = run_eval(
                eval_model, eval_sess, out_dir, config, summary_writer)
            print("%s %s" % ('EVAL LOSS ', eval_loss))

    # Flush pending summaries and release the event file.
    summary_writer.close()
def train(hparams, scope=None, target_session=""):
    """Train a translation model with periodic evaluation.

    Builds the train / eval / inference graphs, restores or creates the model
    in ``hparams.out_dir`` and steps until ``hparams.num_train_steps``.
    Statistics are reported every ``hparams.steps_per_stats`` steps and
    internal evaluation runs every ``10 * steps_per_stats`` steps. External
    evaluation runs every ``hparams.steps_per_external_eval`` steps, or every
    five internal-eval intervals when that value is falsy. When
    ``hparams.avg_ckpts`` is truthy the averaged-checkpoint evaluation is run
    as well. After training, the saved best model for each metric in
    ``hparams.metrics`` is evaluated.

    Args:
        hparams: Hyper-parameter object; ``hparams.epoch_step`` is updated
            in place while training.
        scope: Forwarded to the model creation helpers.
        target_session: ``target`` argument for every ``tf.Session`` created.

    Returns:
        ``(final_eval_metrics, global_step)`` where ``final_eval_metrics`` is
        the third value returned by the final ``run_full_eval`` call.

    Raises:
        ValueError: If attention is enabled with an architecture other than
            ``"standard"``.
    """
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    avg_ckpts = hparams.avg_ckpts

    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if not hparams.attention:
        model_creator = nmt_model.Model
    else:
        if hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    train_model = model_util.create_train_model(model_creator, hparams, scope)
    eval_model = model_util.create_eval_model(model_creator, hparams, scope)
    infer_model = model_util.create_infer_model(model_creator, hparams, scope)

    # Dev data used for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir
    # Every checkpoint written by this function uses the same path prefix.
    ckpt_path = os.path.join(out_dir, "translate.ckpt")

    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)

    train_sess = tf.Session(target=target_session, config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session, config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session, config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_util.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)

    # Evaluate once before any training step is taken.
    run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
                  hparams, summary_writer, sample_src_data, sample_tgt_data,
                  avg_ckpts)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    stats, info, start_train_time = before_train(
        loaded_train_model, train_model, train_sess, global_step, hparams,
        log_f)

    while global_step < num_train_steps:
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # The training iterator is exhausted: evaluate, then restart it
            # from the beginning for the next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data,
                              sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)
            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)
            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)

        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = process_stats(stats, info, global_step,
                                        steps_per_stats, log_f)
            print_step_info(" ", global_step, info,
                            _get_best_results(hparams), log_f)
            # Stop stepping once process_stats reports an overflow.
            if is_overflow:
                break
            # Start a fresh accumulation window.
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_perplexity",
                              info["train_perplexity"])
            loaded_train_model.saver.save(train_sess, ckpt_path,
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data,
                              sample_tgt_data)
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step
            loaded_train_model.saver.save(train_sess, ckpt_path,
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data,
                              sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)
            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

    # Training finished (or stopped early): save and run a final evaluation.
    loaded_train_model.saver.save(train_sess, ckpt_path,
                                  global_step=global_step)
    result_summary, _, final_eval_metrics = run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data, avg_ckpts)
    print_step_info("# Final, ", global_step, info, result_summary, log_f)
    utils.print_time("# Done training!", start_train_time)
    summary_writer.close()

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        # Each best-model evaluation writes summaries into its own directory.
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        print_step_info("# Best %s, " % metric, best_global_step, info,
                        result_summary, log_f)
        summary_writer.close()

        if avg_ckpts:
            best_model_dir = getattr(hparams, "avg_best_" + metric + "_dir")
            summary_writer = tf.summary.FileWriter(
                os.path.join(best_model_dir, summary_name), infer_model.graph)
            result_summary, best_global_step, _ = run_full_eval(
                best_model_dir, infer_model, infer_sess, eval_model,
                eval_sess, hparams, summary_writer, sample_src_data,
                sample_tgt_data)
            print_step_info("# Averaged Best %s, " % metric, best_global_step,
                            info, result_summary, log_f)
            summary_writer.close()

    # Release the log file handle explicitly instead of relying on GC.
    log_f.close()
    return final_eval_metrics, global_step
def train(hparams, scope=None, target_session=""):
    """Train a translation model.

    Builds the train / eval / infer models, runs one full evaluation, then
    loops until the global step reaches `hparams.num_train_steps` (or
    `check_stats` reports an overflow). During the loop it periodically
    prints statistics, saves checkpoints and runs internal / external
    evaluation. After training it saves a final checkpoint, runs a full
    evaluation, and then evaluates the best saved model of every metric in
    `hparams.metrics`.

    Args:
      hparams: Hyperparameter object. This function reads
        `log_device_placement`, `out_dir`, `num_train_steps`,
        `steps_per_stats`, `steps_per_external_eval`, `attention`,
        `attention_architecture`, `dev_src`, `dev_tgt`, `batch_size`,
        `epoch_step`, `metrics` and `best_<metric>_dir`, and passes the
        object on to the model / evaluation helpers. `epoch_step` is
        mutated: incremented after every successful step and reset to 0 at
        the end of each epoch.
      scope: Forwarded unchanged to the `model_helper.create_*_model` calls.
      target_session: Forwarded as `target` to every `tf.Session`.

    Returns:
      A tuple `(dev_scores, test_scores, dev_ppl, test_ppl, global_step)`.
      The first four values come from the full evaluation that is run right
      after the training loop ends.

    Raises:
      ValueError: If `hparams.attention` is non-empty and
        `hparams.attention_architecture` is not one of "standard", "gnmt"
        or "gnmt_v2".
    """
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    # Pick the model class from the attention settings.
    print(hparams.attention)
    if hparams.attention.strip() == "":
        model_creator = nmt_model.Model
        print("using nmt model " + hparams.attention + "done")
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
        print("using attention model " + hparams.attention + "done")
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
        print("using gnmt model")
    else:
        raise ValueError("Unknown model architecture")

    train_model = model_helper.create_train_model(model_creator, hparams, scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams, scope)

    # Preload data for sample decoding.
    dev_src_file = hparams.dev_src
    dev_tgt_file = hparams.dev_tgt
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir
    # Every checkpoint written by this function uses the same path prefix.
    ckpt_path = os.path.join(out_dir, "translate.ckpt")

    # NOTE(review): avg_step_time, speed and train_ppl are never reassigned
    # anywhere in this function, so the "train_ppl" summary and the
    # "# Final" / "# Best" log lines below always report 0.0 for them.
    # Presumably the real values live in `stats` -- confirm and wire through.
    avg_step_time = 0.0
    speed, train_ppl = 0.0, 0.0

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    try:
        utils.print_out("# log_file=%s" % log_file, log_f)

        # TensorFlow model
        config_proto = utils.get_config_proto(
            log_device_placement=log_device_placement)
        train_sess = tf.Session(
            target=target_session, config=config_proto,
            graph=train_model.graph)
        eval_sess = tf.Session(
            target=target_session, config=config_proto,
            graph=eval_model.graph)
        infer_sess = tf.Session(
            target=target_session, config=config_proto,
            graph=infer_model.graph)

        with train_model.graph.as_default():
            loaded_train_model, global_step = model_helper.create_model_Alveo(
                train_model.model, train_sess, "train")

        # Summary writer
        summary_writer = tf.summary.FileWriter(
            os.path.join(out_dir, summary_name), train_model.graph)

        # First evaluation
        run_full_eval(model_dir, infer_model, infer_sess, eval_model,
                      eval_sess, hparams, summary_writer, sample_src_data,
                      sample_tgt_data)

        last_stats_step = global_step
        last_eval_step = global_step
        last_external_eval_step = global_step

        # This is the training loop.
        stats = init_stats()
        start_train_time = time.time()

        utils.print_out(
            "# Start step %d, lr %g, %s" %
            (global_step,
             loaded_train_model.learning_rate.eval(session=train_sess),
             time.ctime()),
            log_f)

        # Initialize all of the iterators
        skip_count = hparams.batch_size * hparams.epoch_step
        utils.print_out(
            "# Init train iterator, skipping %d elements" % skip_count)
        train_sess.run(
            train_model.iterator.initializer,
            feed_dict={train_model.skip_count_placeholder: skip_count})

        while global_step < num_train_steps:
            ### Run a step ###
            start_time = time.time()
            try:
                step_result = loaded_train_model.train(train_sess)
                hparams.epoch_step += 1
            except tf.errors.OutOfRangeError:
                # Finished going through the training dataset.
                # Go to next epoch.
                hparams.epoch_step = 0
                utils.print_out(
                    "# Finished an epoch, step %d. Perform external evaluation"
                    % global_step)
                run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                                  summary_writer, sample_src_data,
                                  sample_tgt_data)
                # Return value intentionally ignored: the values returned by
                # train() come from the full evaluation after the loop.
                run_external_eval(infer_model, infer_sess, model_dir, hparams,
                                  summary_writer)
                train_sess.run(
                    train_model.iterator.initializer,
                    feed_dict={train_model.skip_count_placeholder: 0})
                continue

            # Write step summary and accumulate statistics
            global_step = update_stats(stats, summary_writer, start_time,
                                       step_result)

            # Once in a while, we print statistics.
            if global_step - last_stats_step >= steps_per_stats:
                last_stats_step = global_step
                if check_stats(stats, global_step, steps_per_stats, hparams,
                               log_f):
                    # Overflow reported: stop training.
                    break
                # Reset statistics
                stats = init_stats()

            if global_step - last_eval_step >= steps_per_eval:
                last_eval_step = global_step
                utils.print_out("# Save eval, global step %d" % global_step)
                utils.add_summary(summary_writer, global_step, "train_ppl",
                                  train_ppl)

                # Save checkpoint
                loaded_train_model.saver.save(
                    train_sess, ckpt_path, global_step=global_step)

                # Evaluate on dev/test
                run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                                  summary_writer, sample_src_data,
                                  sample_tgt_data)
                run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                                  summary_writer)

            if global_step - last_external_eval_step >= steps_per_external_eval:
                last_external_eval_step = global_step

                # Save checkpoint
                loaded_train_model.saver.save(
                    train_sess, ckpt_path, global_step=global_step)
                run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                                  summary_writer, sample_src_data,
                                  sample_tgt_data)
                run_external_eval(infer_model, infer_sess, model_dir, hparams,
                                  summary_writer)

        # Done training
        loaded_train_model.saver.save(
            train_sess, ckpt_path, global_step=global_step)

        (result_summary, _, dev_scores, test_scores,
         dev_ppl, test_ppl) = run_full_eval(
             model_dir, infer_model, infer_sess, eval_model, eval_sess,
             hparams, summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out(
            "# Final, step %d lr %g "
            "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
            (global_step,
             loaded_train_model.learning_rate.eval(session=train_sess),
             avg_step_time, speed, train_ppl, result_summary, time.ctime()),
            log_f)
        utils.print_time("# Done training!", start_train_time)

        summary_writer.close()

        utils.print_out("# Start evaluating saved best models.")
        for metric in hparams.metrics:
            best_model_dir = getattr(hparams, "best_" + metric + "_dir")
            # Each best-model evaluation gets its own writer under that
            # model's directory.
            summary_writer = tf.summary.FileWriter(
                os.path.join(best_model_dir, summary_name), infer_model.graph)
            result_summary, best_global_step, _, _, _, _ = run_full_eval(
                best_model_dir, infer_model, infer_sess, eval_model,
                eval_sess, hparams, summary_writer, sample_src_data,
                sample_tgt_data)
            utils.print_out(
                "# Best %s, step %d "
                "step-time %.2f wps %.2fK, %s, %s" %
                (metric, best_global_step, avg_step_time, speed,
                 result_summary, time.ctime()),
                log_f)
            summary_writer.close()

        return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
    finally:
        # The log file was previously never closed; release the handle on
        # both the normal and the error path.
        log_f.close()