def run_internal_eval(self, eval_model, eval_sess, model_dir, summary_writer,
                      use_test_set=True):
    """Compute internal evaluation (perplexity) for both dev / test."""
    with eval_model.graph.as_default():
        loaded_eval_model, global_step = model_helper.create_or_load_model(
            eval_model.model, model_dir, eval_sess, "eval")

    dev_file = self.config.dev_data
    dev_eval_iterator_feed_dict = {
        eval_model.eval_file_placeholder: dev_file
    }
    # _internal_eval initializes the iterator, computes dev perplexity, and
    # writes the "dev_ppl" summary.
    dev_ppl = self._internal_eval(loaded_eval_model, global_step, eval_sess,
                                  eval_model.iterator, dev_eval_iterator_feed_dict,
                                  summary_writer, "dev")

    # Checkpoint whenever the dev perplexity improves on the best seen so far.
    if dev_ppl < self.config.best_dev_ppl:
        loaded_eval_model.saver.save(
            eval_sess,
            os.path.join(self.config.best_dev_ppl_dir, 'taware.ckpt'),
            global_step=global_step)

    test_ppl = None
    if use_test_set:
        test_file = self.config.test_data
        test_eval_iterator_feed_dict = {
            eval_model.eval_file_placeholder: test_file
        }
        test_ppl = self._internal_eval(loaded_eval_model, global_step, eval_sess,
                                       eval_model.iterator, test_eval_iterator_feed_dict,
                                       summary_writer, "test")

    return dev_ppl, test_ppl
def _internal_eval(self, model, global_step, sess, iterator, iterator_feed_dict,
                   summary_writer, label):
    """Compute perplexity on the data fed through `iterator`."""
    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)
    ppl = model_helper.compute_perplexity(model, sess, label)
    log.add_summary(summary_writer, global_step, "%s_ppl" % label, ppl)
    return ppl
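# Illustrative sketch only: the eval path above calls the repo's
# model_helper.compute_perplexity. Assuming the eval model exposes an eval(sess)
# step that returns (loss, predict_count, batch_size) -- an assumption, not a
# confirmed interface of this repo -- perplexity would be computed roughly like
# this: accumulate token-level cross-entropy over the initialized iterator and
# exponentiate the per-token average.
def _compute_perplexity_sketch(model, sess, label):
    import math

    total_loss, total_predict_count = 0.0, 0
    while True:
        try:
            loss, predict_count, batch_size = model.eval(sess)
        except tf.errors.OutOfRangeError:
            # Iterator exhausted: the whole dev/test file has been scored.
            break
        total_loss += loss * batch_size
        total_predict_count += predict_count

    avg_loss = total_loss / max(total_predict_count, 1)
    # Guard against overflow when the model is still far from converged.
    ppl = math.exp(avg_loss) if avg_loss < 300 else float("inf")
    log.print_out("  eval %s: perplexity %.2f" % (label, ppl))
    return ppl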
def run_internal_eval(self, eval_model, eval_sess, model_dir, summary_writer,
                      use_test_set=True):
    """Compute internal evaluation (perplexity) for both dev / test."""
    with eval_model.graph.as_default():
        loaded_eval_model, global_step = model_helper.create_or_load_model(
            eval_model.model, model_dir, eval_sess, "eval")

    dev_eval_iterator_feed_dict = {
        eval_model.eval_file_placeholder: self.config.dev_data
    }
    eval_sess.run(eval_model.iterator.initializer,
                  feed_dict=dev_eval_iterator_feed_dict)
    dev_ppl = model_helper.compute_perplexity(loaded_eval_model, eval_sess, "dev")
    log.add_summary(summary_writer, global_step, "dev_ppl", dev_ppl)

    # Checkpoint whenever the dev perplexity improves on the best seen so far.
    if dev_ppl < self.config.best_dev_ppl:
        loaded_eval_model.saver.save(
            eval_sess,
            os.path.join(self.config.best_dev_ppl_dir,
                         '{}.ckpt'.format(self._get_checkpoint_name())),
            global_step=global_step)

    test_ppl = None
    if use_test_set:
        test_eval_iterator_feed_dict = {
            eval_model.eval_file_placeholder: self.config.test_data
        }
        eval_sess.run(eval_model.iterator.initializer,
                      feed_dict=test_eval_iterator_feed_dict)
        test_ppl = model_helper.compute_perplexity(loaded_eval_model, eval_sess, "test")
        log.add_summary(summary_writer, global_step, "test_ppl", test_ppl)

    return dev_ppl, test_ppl
def _external_eval(self, model, global_step, sess, iterator, iterator_feed_dict,
                   eval_file, label, summary_writer, save_on_best):
    """External evaluation such as BLEU and ROUGE scores."""
    out_dir = self.config.model_dir
    decode = global_step > 0
    if decode:
        log.print_out("# External evaluation, global step %d" % global_step)

    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)

    output = os.path.join(out_dir, "output_%s" % label)
    scores = eval_metric.decode_and_evaluate(
        label,
        model,
        sess,
        output,
        ref_file=eval_file,
        metrics=self.config.metrics,
        beam_width=self.config.beam_width,
        decode=decode)

    # Save on best metrics
    if decode:
        for metric in self.config.metrics:
            log.add_summary(summary_writer, global_step, "%s_%s" % (label, metric),
                            scores[metric])
            # metric: larger is better
            if save_on_best and scores[metric] > getattr(self.config, "best_" + metric):
                setattr(self.config, "best_" + metric, scores[metric])
                model.saver.save(
                    sess,
                    os.path.join(getattr(self.config, "best_" + metric + "_dir"),
                                 "vanilla.ckpt"),
                    global_step=model.global_step)
        # self.config.save(out_dir)
    return scores
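# Hedged usage sketch (not present in this file): how _external_eval would typically
# be wired up for the dev set. The infer_model attribute names (src_placeholder,
# batch_size_placeholder) and config.infer_batch_size follow the usual TF 1.x
# seq2seq infer-model layout and are assumptions, not confirmed names from this repo.
def _run_dev_external_eval_sketch(self, infer_model, infer_sess, global_step,
                                  summary_writer):
    dev_infer_iterator_feed_dict = {
        infer_model.src_placeholder: self._load_data(self.config.dev_data),
        infer_model.batch_size_placeholder: self.config.infer_batch_size,
    }
    # save_on_best=True lets _external_eval checkpoint into best_<metric>_dir
    # whenever a metric such as BLEU improves.
    return self._external_eval(
        infer_model.model, global_step, infer_sess,
        infer_model.iterator, dev_infer_iterator_feed_dict,
        self.config.dev_data, "dev", summary_writer,
        save_on_best=True)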
def train(self, target_session="", scope=None):
    assert self.config.num_turns >= 2
    if self.config.is_pretrain_enabled():
        assert self.config.num_pretrain_turns >= 2
        assert self.config.num_turns >= self.config.num_pretrain_turns

    out_dir = self.config.model_dir
    steps_per_stats = self.config.steps_per_stats
    steps_per_eval = 20 * steps_per_stats

    _helper = self._get_model_helper()

    self._pre_model_creation()

    train_model = _helper.create_train_model(self.config, scope)
    eval_model = _helper.create_eval_model(self.config, scope)
    infer_model = _helper.create_infer_model(self.config, scope)

    self._post_model_creation(train_model, eval_model, infer_model)

    # Preload data for sample decoding.
    dev_file = self.config.dev_data
    eval_data = self._load_data(dev_file, include_target=True)

    summary_name = "train_log"

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    log.print_out("# log_file=%s" % log_file, log_f)

    self.config.save()
    log.print_out("# Configs saved")

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = model_helper.get_config_proto(self.config.log_device)

    train_sess = tf.Session(
        target=target_session, config=config_proto, graph=train_model.graph)
    eval_sess = tf.Session(
        target=target_session, config=config_proto, graph=eval_model.graph)
    infer_sess = tf.Session(
        target=target_session, config=config_proto, graph=infer_model.graph)

    # Pretraining
    num_pretrain_steps = 0
    if self.config.is_pretrain_enabled():
        num_pretrain_steps = self.config.num_pretrain_steps
        pretrain_model = _helper.create_pretrain_model(self.config, scope)
        with tf.Session(
                target=target_session, config=config_proto,
                graph=pretrain_model.graph) as pretrain_sess:
            self.pretrain(pretrain_sess, pretrain_model, log_f)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, self.config.model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)

    last_stats_step = global_step
    last_eval_step = global_step
    patience = self.config.patience

    stats = self.init_stats()
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    log.print_out(
        "# Start step %d, epoch %d, lr %g, %s" %
        (global_step, self.config.epoch,
         loaded_train_model.learning_rate.eval(session=train_sess),
         time.ctime()),
        log_f)

    # Initialize all of the iterators
    skip_count = self.config.batch_size * self.config.epoch_step
    log.print_out("# Init train iterator for %d steps, skipping %d elements" %
                  (self.config.num_train_steps, skip_count))

    train_sess.run(
        train_model.iterator.initializer,
        feed_dict={train_model.skip_count_placeholder: skip_count})

    while self.config.epoch < self.config.num_train_epochs and patience > 0:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            self.config.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset. Go to next epoch.
            sw = Stopwatch()
            self.run_sample_decode(infer_model, infer_sess, self.config.model_dir,
                                   summary_writer, eval_data)
            # if self.config.enable_epoch_evals:
            #     dev_ppl, test_ppl = self.run_full_eval(infer_model, eval_model,
            #                                            infer_sess, eval_sess,
            #                                            out_dir,
            #                                            fs.file_name(self.config.test_data) + '_' + global_step,
            #                                            summary_writer)
            #     log.print_out(
            #         "%% done epoch %d #%d step %d - dev_ppl: %.2f test_ppl: %.2f @ eval time: %ds" %
            #         (self.config.epoch, self.config.epoch_step, global_step, dev_ppl, test_ppl, sw.elapsed()))
            # else:
            log.print_out(
                "## Done epoch %d in %d steps. step %d @ eval time: %ds" %
                (self.config.epoch, self.config.epoch_step, global_step, sw.elapsed()))

            self.config.epoch += 1
            self.config.epoch_step = 0
            self.config.save()

            train_sess.run(
                train_model.iterator.initializer,
                feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary and accumulate statistics
        global_step = self.update_stats(stats, summary_writer, start_time, step_result)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            train_ppl, speed, is_overflow = self.check_stats(stats, global_step,
                                                             steps_per_stats, log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = self.init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step

            log.print_out("# Save eval, global step %d" % global_step)
            log.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          self.config.checkpoint_file,
                                          global_step=global_step)

            # Evaluate on dev
            self.run_sample_decode(infer_model, infer_sess, out_dir,
                                   summary_writer, eval_data)
            dev_ppl, _ = self.run_internal_eval(eval_model, eval_sess, out_dir,
                                                summary_writer, use_test_set=False)

            if dev_ppl < self.config.best_dev_ppl:
                self.config.best_dev_ppl = dev_ppl
                patience = self.config.patience
                log.print_out(' ** Best model thus far, ep {}|{} dev_ppl {:.3f}'.format(
                    self.config.epoch, self.config.epoch_step, dev_ppl))
            elif dev_ppl > self.config.degrade_threshold * self.config.best_dev_ppl:
                patience -= 1
                log.print_out(
                    ' worsened, ep {}|{} patience {} best_dev_ppl {:.3f}'.format(
                        self.config.epoch, self.config.epoch_step, patience,
                        self.config.best_dev_ppl))

            # Save config parameters
            self.config.save()

    # Done training
    loaded_train_model.saver.save(
        train_sess,
        self.config.checkpoint_file,
        global_step=global_step)

    if self.config.enable_final_eval:
        dev_ppl, test_ppl = self.run_full_eval(infer_model, eval_model,
                                               infer_sess, eval_sess,
                                               out_dir,
                                               fs.file_name(self.config.test_data) + '_final',
                                               summary_writer)
        log.print_out(
            "# Final, step %d ep %d/%d lr %g "
            "step-time %.2f wps %.2fK train_ppl %.2f, dev_ppl %.2f, test_ppl %.2f, %s" %
            (global_step, self.config.epoch, self.config.epoch_step,
             loaded_train_model.learning_rate.eval(session=train_sess),
             avg_step_time, speed, train_ppl, dev_ppl, test_ppl, time.ctime()),
            log_f)
    else:
        log.print_out(
            "# Final, step %d ep %d/%d lr %g "
            "step-time %.2f wps %.2fK train_ppl %.2f best_dev_ppl %.2f, %s" %
            (global_step, self.config.epoch, self.config.epoch_step,
             loaded_train_model.learning_rate.eval(session=train_sess),
             avg_step_time, speed, train_ppl, self.config.best_dev_ppl, time.ctime()),
            log_f)

    log.print_time("# Done training!", start_train_time)

    summary_writer.close()
    eval_sess.close()
    infer_sess.close()
    train_sess.close()
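# Minimal sketch of the Stopwatch helper assumed by the epoch-boundary logging above;
# the repo's own implementation may differ. It records a start time on construction
# and reports elapsed wall-clock seconds via elapsed().
class _StopwatchSketch(object):
    def __init__(self):
        self._start = time.time()

    def elapsed(self):
        return time.time() - self._start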
def train(self, target_session="", scope=None):
    out_dir = self.config.model_dir
    model_dir = out_dir

    num_train_steps = self.config.num_train_steps
    steps_per_stats = self.config.steps_per_stats
    # steps_per_external_eval = self.config.steps_per_external_eval
    steps_per_eval = 20 * steps_per_stats
    # if not steps_per_external_eval:
    #     steps_per_external_eval = 5 * steps_per_eval

    self._pre_model_creation()

    train_model = taware_helper.create_train_model(
        taware_model.TopicAwareSeq2SeqModel, self.config, scope)
    eval_model = taware_helper.create_eval_model(
        taware_model.TopicAwareSeq2SeqModel, self.config, scope)
    infer_model = taware_helper.create_infer_model(
        taware_model.TopicAwareSeq2SeqModel, self.config, scope)

    # Preload data for sample decoding.
    dev_file = self.config.dev_data
    eval_data = self._load_data(dev_file, include_target=True)

    summary_name = "train_log"

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    log.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = models.model_helper.get_config_proto(self.config.log_device)

    train_sess = tf.Session(
        target=target_session, config=config_proto, graph=train_model.graph)
    eval_sess = tf.Session(
        target=target_session, config=config_proto, graph=eval_model.graph)
    infer_sess = tf.Session(
        target=target_session, config=config_proto, graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)

    # First evaluation
    # self.run_full_eval(
    #     model_dir, infer_model, infer_sess,
    #     eval_model, eval_sess, summary_writer, eval_data)

    last_stats_step = global_step
    last_eval_step = global_step
    # last_external_eval_step = global_step
    patience = self.config.patience

    # This is the training loop.
    stats = self.init_stats()
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    log.print_out(
        "# Start step %d, epoch %d, lr %g, %s" %
        (global_step, self.config.epoch,
         loaded_train_model.learning_rate.eval(session=train_sess),
         time.ctime()),
        log_f)

    self.config.save()
    log.print_out("# Configs saved")

    # Initialize all of the iterators
    skip_count = self.config.batch_size * self.config.epoch_step
    log.print_out("# Init train iterator for %d steps, skipping %d elements" %
                  (self.config.num_train_steps, skip_count))

    train_sess.run(
        train_model.iterator.initializer,
        feed_dict={train_model.skip_count_placeholder: skip_count})

    while self.config.epoch < self.config.num_train_epochs and patience > 0:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            self.config.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset. Go to next epoch.
            sw = Stopwatch()
            log.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            self.run_sample_decode(infer_model, infer_sess,
                                   model_dir, summary_writer, eval_data)
            log.print_out(
                "## Done epoch %d in %d steps. step %d @ eval time: %ds" %
                (self.config.epoch, self.config.epoch_step, global_step, sw.elapsed()))

            self.config.epoch += 1
            self.config.epoch_step = 0
            self.config.save()

            train_sess.run(
                train_model.iterator.initializer,
                feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary and accumulate statistics
        global_step = self.update_stats(stats, summary_writer, start_time, step_result)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            train_ppl, speed, is_overflow = self.check_stats(stats, global_step,
                                                             steps_per_stats, log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = self.init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step

            log.print_out("# Save eval, global step %d" % global_step)
            log.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(
                train_sess,
                self.config.checkpoint_file,
                global_step=global_step)

            # Evaluate on dev
            self.run_sample_decode(infer_model, infer_sess, model_dir,
                                   summary_writer, eval_data)
            dev_ppl, _ = self.run_internal_eval(eval_model, eval_sess, model_dir,
                                                summary_writer, use_test_set=False)

            if dev_ppl < self.config.best_dev_ppl:
                self.config.best_dev_ppl = dev_ppl
                patience = self.config.patience
                log.print_out(' ** Best model thus far, ep {}|{} dev_ppl {:.3f}'.format(
                    self.config.epoch, self.config.epoch_step, dev_ppl))
            elif dev_ppl > self.config.degrade_threshold * self.config.best_dev_ppl:
                patience -= 1
                log.print_out(
                    ' worsened, ep {}|{} patience {} best_dev_ppl {:.3f}'.format(
                        self.config.epoch, self.config.epoch_step, patience,
                        self.config.best_dev_ppl))

            # Save config parameters
            self.config.save()

        # if global_step - last_external_eval_step >= steps_per_external_eval:
        #     last_external_eval_step = global_step
        #
        #     # Save checkpoint
        #     loaded_train_model.saver.save(
        #         train_sess,
        #         self.config.checkpoint_file,
        #         global_step=global_step)
        #     self.run_sample_decode(infer_model, infer_sess,
        #                            model_dir, summary_writer, eval_data)
        #     dev_scores, test_scores, _ = self.run_external_eval(
        #         infer_model, infer_sess, model_dir, summary_writer)

    # Done training
    loaded_train_model.saver.save(
        train_sess,
        self.config.checkpoint_file,
        global_step=global_step)

    # result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = self.run_full_eval(
    #     model_dir, infer_model, infer_sess,
    #     eval_model, eval_sess,
    #     summary_writer, eval_data)
    dev_scores, test_scores, dev_ppl, test_ppl = None, None, None, None
    result_summary = ""

    log.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()),
        log_f)
    log.print_time("# Done training!", start_train_time)

    summary_writer.close()

    # log.print_out("# Start evaluating saved best models.")
    # for metric in self.config.metrics:
    #     best_model_dir = getattr(self.config, "best_" + metric + "_dir")
    #     summary_writer = tf.summary.FileWriter(
    #         os.path.join(best_model_dir, summary_name), infer_model.graph)
    #     result_summary, best_global_step, _, _, _, _ = self.run_full_eval(
    #         best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
    #         summary_writer, eval_data)
    #     log.print_out("# Best %s, step %d "
    #                   "step-time %.2f wps %.2fK, %s, %s" %
    #                   (metric, best_global_step, avg_step_time, speed,
    #                    result_summary, time.ctime()), log_f)
    #     summary_writer.close()

    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
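# For context, a hedged sketch of the log.add_summary helper used throughout the
# training and eval code above (the repo's actual implementation may differ): it
# wraps a scalar into a tf.Summary protobuf and writes it at the given step, which
# is how train_ppl / dev_ppl / test_ppl end up as TensorBoard scalars.
def _add_summary_sketch(summary_writer, global_step, tag, value):
    summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
    summary_writer.add_summary(summary, global_step)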