def run_internal_eval(self, eval_model, eval_sess, model_dir, summary_writer, use_test_set=True):
    """Compute internal evaluation (perplexity) for both dev / test."""
    with eval_model.graph.as_default():
        loaded_eval_model, global_step = model_helper.create_or_load_model(
            eval_model.model, model_dir, eval_sess, "eval")

    dev_file = self.config.dev_data
    dev_eval_iterator_feed_dict = {
        eval_model.eval_file_placeholder: dev_file
    }
    dev_ppl = self._internal_eval(loaded_eval_model, global_step, eval_sess,
                                  eval_model.iterator, dev_eval_iterator_feed_dict,
                                  summary_writer, "dev")
    log.add_summary(summary_writer, global_step, "dev_ppl", dev_ppl)

    if dev_ppl < self.config.best_dev_ppl:
        loaded_eval_model.saver.save(eval_sess,
                                     os.path.join(self.config.best_dev_ppl_dir, 'taware.ckpt'),
                                     global_step=global_step)

    test_ppl = None
    if use_test_set:
        test_file = self.config.test_data
        test_eval_iterator_feed_dict = {
            eval_model.eval_file_placeholder: test_file
        }
        test_ppl = self._internal_eval(loaded_eval_model, global_step, eval_sess,
                                       eval_model.iterator, test_eval_iterator_feed_dict,
                                       summary_writer, "test")

    return dev_ppl, test_ppl
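# Every snippet in this section loads its graph through
# model_helper.create_or_load_model. A minimal sketch of that helper, assuming
# the TF1 NMT-style convention these functions follow (restore the newest
# checkpoint if one exists, otherwise initialize fresh variables); the real
# model_helper may differ in detail:
import tensorflow as tf

def create_or_load_model(model, model_dir, session, name):
    latest_ckpt = tf.train.latest_checkpoint(model_dir)
    if latest_ckpt:
        # Restore weights from the newest checkpoint in model_dir.
        model.saver.restore(session, latest_ckpt)
    else:
        # No checkpoint yet: start from freshly initialized variables.
        session.run(tf.global_variables_initializer())
    # Vocabulary lookup tables always need explicit initialization.
    session.run(tf.tables_initializer())
    global_step = model.global_step.eval(session=session)
    print("  %s model loaded, global step %d" % (name, global_step))
    return model, global_step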
def __init__(self, hparams):
    self.hparams = hparams

    # Data locations
    self.out_dir = hparams.out_dir
    self.model_dir = os.path.join(self.out_dir, 'ckpts')

    # Create models
    attention_option = hparams.attention_option
    if attention_option:
        model_creator = AttentionModel
    else:
        model_creator = BasicModel

    self.infer_model = model_helper.create_infer_model(
        hparams=hparams, model_creator=model_creator)

    # Sessions
    config_proto = utils.get_config_proto()
    self.infer_sess = tf.Session(config=config_proto, graph=self.infer_model.graph)

    # EOS
    self.tgt_eos = Vocabulary.EOS.encode("utf-8")

    # Load infer model
    with self.infer_model.graph.as_default():
        self.loaded_infer_model, self.global_step = model_helper.create_or_load_model(
            self.infer_model.model, self.model_dir, self.infer_sess, "infer")
def run_external_evaluation(infer_model, infer_sess, model_dir, hparams,
                            summary_writer, save_on_best_dev):
    with infer_model.graph.as_default():
        # Load the model from checkpoint. It automatically loads the latest checkpoint.
        loaded_infer_model, global_step = model_helper.create_or_load_model(
            model=infer_model.model,
            model_dir=model_dir,
            session=infer_sess,
            name="infer"
        )

    # Fill the feed_dict for the evaluation
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    inference_dev_data = inference.load_data(dev_src_file)
    dev_infer_iterator_feed_dict = {
        infer_model.src_placeholder: inference_dev_data,
        infer_model.batch_size_placeholder: hparams.infer_batch_size
    }

    dev_scores = _external_eval(
        model=loaded_infer_model,
        global_step=global_step,
        sess=infer_sess,
        hparams=hparams,
        iterator=infer_model.iterator,
        iterator_feed_dict=dev_infer_iterator_feed_dict,
        tgt_file=dev_tgt_file,
        label="dev",
        summary_writer=summary_writer,
        save_on_best_dev=save_on_best_dev
    )

    test_scores = None
    if hparams.test_prefix:
        # Create the test data
        test_src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
        test_tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
        inference_test_data = inference.load_data(test_src_file)
        test_infer_iterator_feed_dict = {
            infer_model.src_placeholder: inference_test_data,
            infer_model.batch_size_placeholder: hparams.infer_batch_size
        }
        # Run evaluation on the test dataset. We never save on the test score:
        # selecting checkpoints by test performance would amount to overfitting it.
        test_scores = _external_eval(
            model=loaded_infer_model,
            global_step=global_step,
            sess=infer_sess,
            hparams=hparams,
            iterator=infer_model.iterator,
            iterator_feed_dict=test_infer_iterator_feed_dict,
            tgt_file=test_tgt_file,
            label="test",
            summary_writer=summary_writer,
            save_on_best_dev=False
        )

    return dev_scores, test_scores, global_step
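# _external_eval is referenced above but not defined in this section. A
# hypothetical sketch of its usual shape: initialize the iterator, decode the
# set to a file, score it per metric, log the score, and checkpoint on a new
# best dev score. The evaluation_utils.evaluate call, the decode_to_file
# helper, and the best_<metric> attributes are assumptions, not the
# repository's confirmed API:
import os

def _external_eval(model, global_step, sess, hparams, iterator,
                   iterator_feed_dict, tgt_file, label, summary_writer,
                   save_on_best_dev):
    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)
    trans_file = os.path.join(hparams.out_dir, "output_%s" % label)
    decode_to_file(model, sess, trans_file, hparams)  # hypothetical decode loop
    scores = {}
    for metric in hparams.metrics:
        scores[metric] = evaluation_utils.evaluate(tgt_file, trans_file, metric)
        utils.add_summary(summary_writer, global_step,
                          "%s_%s" % (label, metric), scores[metric])
        # Higher is better for metrics like BLEU; keep the best dev checkpoint.
        if save_on_best_dev and scores[metric] > getattr(hparams, "best_" + metric):
            setattr(hparams, "best_" + metric, scores[metric])
            model.saver.save(sess,
                             os.path.join(getattr(hparams, "best_" + metric + "_dir"),
                                          "chatbot.ckpt"),
                             global_step=global_step)
    return scores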
def run_sample_decode(self, infer_model, infer_sess, model_dir, summary_writer, eval_data):
    """Sample decode a random sentence from src_data."""
    with infer_model.graph.as_default():
        loaded_infer_model, global_step = model_helper.create_or_load_model(
            infer_model.model, model_dir, infer_sess, "infer")

    self.__sample_decode(loaded_infer_model, global_step, infer_sess,
                         infer_model.iterator, eval_data,
                         infer_model.src_placeholder,
                         infer_model.batch_size_placeholder, summary_writer)
def run_external_eval(self, infer_model, infer_sess, model_dir, summary_writer,
                      save_best_dev=True, use_test_set=True):
    """Compute external evaluation (bleu, rouge, etc.) for both dev / test."""
    with infer_model.graph.as_default():
        loaded_infer_model, global_step = model_helper.create_or_load_model(
            infer_model.model, model_dir, infer_sess, "infer")

    dev_infer_iterator_feed_dict = {
        infer_model.src_placeholder: self._load_data(self.config.dev_data),
        infer_model.batch_size_placeholder: self.config.infer_batch_size,
    }
    dev_scores = self._external_eval(loaded_infer_model, global_step, infer_sess,
                                     infer_model.iterator, dev_infer_iterator_feed_dict,
                                     self.config.dev_data, "dev", summary_writer,
                                     save_on_best=save_best_dev)

    test_scores = None
    if use_test_set:
        test_file = self.config.test_data
        test_infer_iterator_feed_dict = {
            infer_model.src_placeholder: self._load_data(test_file),
            infer_model.batch_size_placeholder: self.config.infer_batch_size,
        }
        test_scores = self._external_eval(loaded_infer_model, global_step, infer_sess,
                                          infer_model.iterator, test_infer_iterator_feed_dict,
                                          test_file, "test", summary_writer,
                                          save_on_best=False)

    return dev_scores, test_scores, global_step
def chat(self):
    """Accept an input string and get a response from the trained model."""
    model_dir = self.model_dir
    infer_model = self.infer_model
    infer_sess = self.infer_sess
    beam_width = self.hparams.beam_width

    # Load infer model
    with infer_model.graph.as_default():
        loaded_infer_model, global_step = model_helper.create_or_load_model(
            infer_model.model, model_dir, infer_sess, "infer")

    # Warm up jieba
    jieba.lcut("jieba")

    print("Enter 'q' or '退出' to quit!\n\n")
    while True:
        input_str = input('Me > ')
        if not input_str.strip():
            continue
        if input_str == "q" or input_str == "退出":
            break

        input_seg = jieba.lcut(input_str)

        start_time = time.time()
        iterator_feed_dict = {
            infer_model.src_data_placeholder: input_seg,
            infer_model.batch_size_placeholder: 1
        }
        infer_sess.run(self.infer_model.iterator.initializer,
                       feed_dict=iterator_feed_dict)
        sample_words = loaded_infer_model.decode(infer_sess)
        if beam_width > 0:
            # Get a random answer among the beams.
            beam_id = random.randint(0, beam_width - 1)
            sample_words = sample_words[beam_id]

        response = self._get_response(sample_words)
        # Remove the spaces between tokens before printing.
        response = "".join(re.split(" ", response))
        print("AI > %s (%.4fs)" % (response, time.time() - start_time))
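# The _get_response helper used by chat() (and the decode loops below) is
# assumed to turn the decoder's byte tokens for one sentence into a string,
# truncated at EOS. A minimal sketch under that assumption, using the tgt_eos
# attribute set in __init__ above; the real helper may differ:
def _get_response(self, sample_words):
    output = list(sample_words)
    if self.tgt_eos and self.tgt_eos in output:
        # Keep only the tokens before the first end-of-sentence marker.
        output = output[:output.index(self.tgt_eos)]
    # Decoder tokens are utf-8 byte strings; join and decode to text.
    return b" ".join(output).decode("utf-8")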
def sample_decode(self, num_sentences=1):
    """Sample decode num_sentences random sentences from src_data."""
    model_dir = self.model_dir
    infer_model = self.infer_model
    infer_sess = self.infer_sess
    train_src_file = self.train_src_file
    train_tgt_file = self.train_tgt_file
    beam_width = self.hparams.beam_width
    start_time = time.time()

    # Load infer model
    with infer_model.graph.as_default():
        loaded_infer_model, global_step = model_helper.create_or_load_model(
            infer_model.model, model_dir, infer_sess, "infer")

    with open(train_src_file, encoding='utf-8') as f:
        src_data = f.readlines()
    with open(train_tgt_file, encoding='utf-8') as f:
        tgt_data = f.readlines()

    for _ in range(num_sentences):
        decode_id = random.randint(0, len(src_data) - 1)
        print("# Decoding sentence %d" % decode_id)

        iterator_feed_dict = {
            infer_model.src_data_placeholder: [src_data[decode_id]],
            infer_model.batch_size_placeholder: 1
        }
        infer_sess.run(self.infer_model.iterator.initializer,
                       feed_dict=iterator_feed_dict)
        sample_words = loaded_infer_model.decode(infer_sess)

        if beam_width > 0:
            # Get the top translation.
            sample_words = sample_words[0]

        response = self._get_response(sample_words)
        print(" src: %s" % src_data[decode_id], end='')
        print(" ref: %s" % tgt_data[decode_id], end='')
        print(" bot: %s" % response)
        print(" tim: %.4fs" % (time.time() - start_time))
def _get_eval_perplexity(self, name):
    model_dir = self.model_dir
    eval_model = self.eval_model
    eval_sess = self.eval_sess

    with eval_model.graph.as_default():
        loaded_eval_model, global_step = model_helper.create_or_load_model(
            eval_model.model, model_dir, eval_sess, 'eval')

    dev_eval_iterator_feed_dict = {
        eval_model.src_file_placeholder: self.dev_src_file,
        eval_model.tgt_file_placeholder: self.dev_tgt_file
    }
    # Evaluate the freshly restored model (not the model container).
    dev_ppl = eval_utils.internal_eval(
        loaded_eval_model, global_step, eval_sess, eval_model.iterator,
        dev_eval_iterator_feed_dict, name)
    return dev_ppl
def run_internal_eval(eval_model, eval_sess, model_dir, hparams, summary_writer):
    """Compute internal evaluation (perplexity) for both dev / test."""
    with eval_model.graph.as_default():
        # Load the latest checkpoint from file
        loaded_eval_model, global_step = model_helper.create_or_load_model(
            model=eval_model.model,
            model_dir=model_dir,
            session=eval_sess,
            name="eval"
        )

    # Fill the feed_dict for the evaluation
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    dev_eval_iterator_feed_dict = {
        eval_model.src_file_placeholder: dev_src_file,
        eval_model.tgt_file_placeholder: dev_tgt_file
    }
    # Run evaluation on the development (validation) dataset
    dev_ppl = _internal_eval(loaded_eval_model, global_step, eval_sess,
                             iterator=eval_model.iterator,
                             iterator_feed_dict=dev_eval_iterator_feed_dict,
                             summary_writer=summary_writer,
                             label='dev')

    test_ppl = None
    if hparams.test_prefix:
        # Create the test data
        test_src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
        test_tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
        test_eval_iterator_feed_dict = {
            eval_model.src_file_placeholder: test_src_file,
            eval_model.tgt_file_placeholder: test_tgt_file
        }
        # Run evaluation on the test dataset
        test_ppl = _internal_eval(loaded_eval_model, global_step, eval_sess,
                                  iterator=eval_model.iterator,
                                  iterator_feed_dict=test_eval_iterator_feed_dict,
                                  summary_writer=summary_writer,
                                  label='test')

    return dev_ppl, test_ppl
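# A minimal sketch of the _internal_eval helper called above, following the
# TF1 NMT-style convention these snippets assume: initialize the evaluation
# iterator with the given feed_dict, compute perplexity over the full set, and
# log it. The compute_perplexity helper is sketched after the next function:
def _internal_eval(model, global_step, sess, iterator, iterator_feed_dict,
                   summary_writer, label):
    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)
    ppl = model_helper.compute_perplexity(model, sess, label)
    utils.add_summary(summary_writer, global_step, "%s_ppl" % label, ppl)
    return ppl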
def run_internal_eval(self, eval_model, eval_sess, model_dir, summary_writer, use_test_set=True):
    """Compute internal evaluation (perplexity) for both dev / test."""
    with eval_model.graph.as_default():
        loaded_eval_model, global_step = model_helper.create_or_load_model(
            eval_model.model, model_dir, eval_sess, "eval")

    dev_eval_iterator_feed_dict = {
        eval_model.eval_file_placeholder: self.config.dev_data
    }
    eval_sess.run(eval_model.iterator.initializer, feed_dict=dev_eval_iterator_feed_dict)
    dev_ppl = model_helper.compute_perplexity(loaded_eval_model, eval_sess, "dev")
    log.add_summary(summary_writer, global_step, "dev_ppl", dev_ppl)

    if dev_ppl < self.config.best_dev_ppl:
        loaded_eval_model.saver.save(
            eval_sess,
            os.path.join(self.config.best_dev_ppl_dir,
                         '{}.ckpt'.format(self._get_checkpoint_name())),
            global_step=global_step)

    test_ppl = None
    if use_test_set:
        test_eval_iterator_feed_dict = {
            eval_model.eval_file_placeholder: self.config.test_data
        }
        eval_sess.run(eval_model.iterator.initializer, feed_dict=test_eval_iterator_feed_dict)
        test_ppl = model_helper.compute_perplexity(loaded_eval_model, eval_sess, "test")
        log.add_summary(summary_writer, global_step, "test_ppl", test_ppl)

    return dev_ppl, test_ppl
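# A minimal sketch of model_helper.compute_perplexity as used above: sweep the
# already-initialized iterator, accumulate per-word cross-entropy, and
# exponentiate the average. The (loss, predict_count, batch_size) return
# signature of model.eval is an assumption:
import math

def compute_perplexity(model, sess, name):
    total_loss, total_predict_count = 0.0, 0
    while True:
        try:
            loss, predict_count, batch_size = model.eval(sess)
            total_loss += loss * batch_size
            total_predict_count += predict_count
        except tf.errors.OutOfRangeError:
            # One full pass over the evaluation set is done.
            break
    # Cap the exponent so a diverged model reports a large finite perplexity.
    perplexity = math.exp(min(total_loss / total_predict_count, 100.0))
    print("  eval %s: perplexity %.2f" % (name, perplexity))
    return perplexity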
def run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                      summary_writer, src_data, tgt_data):
    """
    Sample decode a random sentence from the source data. Used to show
    tangible progress of the model during training.

    :param infer_model: The model used to produce the response.
    :param infer_sess: The session the inference graph runs in.
    :param model_dir: Directory which contains the trained model.
    :param hparams: Hyperparameters used for decoding.
    :param summary_writer: An instance of a TensorFlow summary writer.
    :param src_data: Source sentences to sample from.
    :param tgt_data: Target (reference) sentences, aligned with src_data.
    """
    with infer_model.graph.as_default():
        # Load the model from checkpoint. It automatically loads the latest checkpoint.
        loaded_infer_model, global_step = model_helper.create_or_load_model(
            model=infer_model.model,
            model_dir=model_dir,
            session=infer_sess,
            name="infer"
        )

    _sample_decode(model=loaded_infer_model,
                   global_step=global_step,
                   sess=infer_sess,
                   hparams=hparams,
                   iterator=infer_model.iterator,
                   src_data=src_data,
                   tgt_data=tgt_data,
                   iterator_src_placeholder=infer_model.src_placeholder,
                   iterator_batch_size_placeholder=infer_model.batch_size_placeholder,
                   summary_writer=summary_writer)
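# A hypothetical sketch of the _sample_decode helper named above, mirroring
# the inline logic of sample_decode earlier in this section: feed one random
# sentence, decode it, and print source / reference / response. Shapes assume
# batch size 1; global_step and summary_writer are accepted for parity but
# unused here:
import random

def _sample_decode(model, global_step, sess, hparams, iterator, src_data,
                   tgt_data, iterator_src_placeholder,
                   iterator_batch_size_placeholder, summary_writer):
    decode_id = random.randint(0, len(src_data) - 1)
    sess.run(iterator.initializer,
             feed_dict={iterator_src_placeholder: [src_data[decode_id]],
                        iterator_batch_size_placeholder: 1})
    sample_words = model.decode(sess)
    if hparams.beam_width > 0:
        # Keep only the top beam.
        sample_words = sample_words[0]
    # First (and only) sentence in the batch; tokens are utf-8 byte strings.
    response = b" ".join(sample_words[0]).decode("utf-8")
    print(" src: %s" % src_data[decode_id])
    print(" ref: %s" % tgt_data[decode_id])
    print(" bot: %s" % response)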
def run_full_eval(self, infer_model, eval_model, infer_sess, eval_sess,
                  model_dir, label, summary_writer):
    dev_ppl, test_ppl = self.run_internal_eval(eval_model, eval_sess, model_dir,
                                               summary_writer, use_test_set=True)

    with infer_model.graph.as_default():
        loaded_infer_model, _ = model_helper.create_or_load_model(
            infer_model.model, self.config.model_dir, infer_sess, "infer")

    infer_feed_dict = {
        infer_model.src_placeholder: self._load_data(self.config.test_data),
        infer_model.batch_size_placeholder: self.config.infer_batch_size,
    }
    self._decode_and_evaluate(loaded_infer_model, infer_sess, infer_feed_dict,
                              label=label)
    return dev_ppl, test_ppl
def train(self):
    hparams = self.hparams
    train_model = self.train_model
    train_sess = self.train_sess
    model_dir = self.model_dir
    steps_per_stats = hparams.steps_per_stats
    num_train_steps = hparams.num_train_steps
    summary_name = "train_log"

    # Load train model
    with self.train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            self.train_model.model, self.model_dir, self.train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(self.out_dir, summary_name), train_model.graph)

    # Initialize dataset iterator
    train_sess.run(
        train_model.iterator.initializer,
        feed_dict={train_model.skip_count_placeholder: 0})

    loss_track = []
    training_start_time = time.time()
    epoch_count = 0
    last_stats_step = global_step
    stats = train_utils.init_stats()
    best_bleu_score = 0

    while global_step < num_train_steps:
        # Run a training step
        start_time = time.time()
        try:
            train_result = loaded_train_model.train(train_sess)
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset. Go to next epoch.
            epoch_count += 1
            print("# Finished epoch %d, step %d." % (epoch_count, global_step))

            # Save model params
            loaded_train_model.saver.save(
                train_sess,
                os.path.join(model_dir, "chatbot.ckpt"),
                global_step=global_step)

            # Do evaluation
            self.eval(best_bleu_score)

            train_sess.run(
                train_model.iterator.initializer,
                feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary and accumulate statistics
        global_step = train_utils.update_stats(
            stats, summary_writer, start_time, train_result.values(), best_bleu_score)
        loss_track.append(train_result['train_loss'])

        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = train_utils.check_stats(stats, global_step, steps_per_stats)
            if is_overflow:
                break

            # Reset statistics
            stats = train_utils.init_stats()

    # Training done.
    loaded_train_model.saver.save(
        train_sess,
        os.path.join(model_dir, "chatbot.ckpt"),
        global_step=global_step)
    summary_writer.close()
    print('Training done. Total time: %.4f' % (time.time() - training_start_time))
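# The init_stats / update_stats helpers used by the training loops in this
# section accumulate step time, loss, and word counts between logging points.
# A minimal sketch of the four-argument variant used by the class-based
# trainers below; the field names and the step_result layout mirror the inline
# accounting of the last train() in this section, and signatures vary between
# variants (train_utils.update_stats above takes extra arguments):
import time

def init_stats():
    return {"step_time": 0.0, "loss": 0.0,
            "predict_count": 0.0, "total_count": 0.0}

def update_stats(stats, summary_writer, start_time, step_result):
    (_, step_loss, step_predict_count, step_summary, global_step,
     step_word_count, batch_size) = step_result
    summary_writer.add_summary(step_summary, global_step)
    stats["step_time"] += time.time() - start_time
    stats["loss"] += step_loss * batch_size
    stats["predict_count"] += step_predict_count
    stats["total_count"] += float(step_word_count)
    return global_step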
def train(self, target_session="", scope=None):
    out_dir = self.config.model_dir
    model_dir = out_dir

    num_train_steps = self.config.num_train_steps
    steps_per_stats = self.config.steps_per_stats
    # steps_per_external_eval = self.config.steps_per_external_eval
    steps_per_eval = 20 * steps_per_stats
    # if not steps_per_external_eval:
    #     steps_per_external_eval = 5 * steps_per_eval

    self._pre_model_creation()

    train_model = taware_helper.create_train_model(taware_model.TopicAwareSeq2SeqModel, self.config, scope)
    eval_model = taware_helper.create_eval_model(taware_model.TopicAwareSeq2SeqModel, self.config, scope)
    infer_model = taware_helper.create_infer_model(taware_model.TopicAwareSeq2SeqModel, self.config, scope)

    # Preload data for sample decoding.
    dev_file = self.config.dev_data
    eval_data = self._load_data(dev_file, include_target=True)

    summary_name = "train_log"

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    log.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = models.model_helper.get_config_proto(self.config.log_device)

    train_sess = tf.Session(
        target=target_session, config=config_proto, graph=train_model.graph)
    eval_sess = tf.Session(
        target=target_session, config=config_proto, graph=eval_model.graph)
    infer_sess = tf.Session(
        target=target_session, config=config_proto, graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)

    # First evaluation
    # self.run_full_eval(
    #     model_dir, infer_model, infer_sess,
    #     eval_model, eval_sess, summary_writer, eval_data)

    last_stats_step = global_step
    last_eval_step = global_step
    # last_external_eval_step = global_step
    patience = self.config.patience

    # This is the training loop.
    stats = self.init_stats()
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    log.print_out(
        "# Start step %d, epoch %d, lr %g, %s" %
        (global_step, self.config.epoch,
         loaded_train_model.learning_rate.eval(session=train_sess),
         time.ctime()),
        log_f)

    self.config.save()
    log.print_out("# Configs saved")

    # Initialize all of the iterators
    skip_count = self.config.batch_size * self.config.epoch_step
    log.print_out("# Init train iterator for %d steps, skipping %d elements" %
                  (self.config.num_train_steps, skip_count))

    train_sess.run(
        train_model.iterator.initializer,
        feed_dict={train_model.skip_count_placeholder: skip_count})

    while self.config.epoch < self.config.num_train_epochs and patience > 0:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            self.config.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset. Go to next epoch.
            sw = Stopwatch()
            log.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            self.run_sample_decode(infer_model, infer_sess,
                                   model_dir, summary_writer, eval_data)
            log.print_out(
                "## Done epoch %d in %d steps. step %d @ eval time: %ds" %
                (self.config.epoch, self.config.epoch_step, global_step, sw.elapsed()))

            self.config.epoch += 1
            self.config.epoch_step = 0
            self.config.save()

            train_sess.run(
                train_model.iterator.initializer,
                feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary and accumulate statistics
        global_step = self.update_stats(stats, summary_writer, start_time, step_result)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            train_ppl, speed, is_overflow = self.check_stats(stats, global_step, steps_per_stats, log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = self.init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            log.print_out("# Save eval, global step %d" % global_step)
            log.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(
                train_sess,
                self.config.checkpoint_file,
                global_step=global_step)

            # Evaluate on dev
            self.run_sample_decode(infer_model, infer_sess, model_dir, summary_writer, eval_data)
            dev_ppl, _ = self.run_internal_eval(eval_model, eval_sess, model_dir, summary_writer,
                                                use_test_set=False)

            if dev_ppl < self.config.best_dev_ppl:
                self.config.best_dev_ppl = dev_ppl
                patience = self.config.patience
                log.print_out(' ** Best model thus far, ep {}|{} dev_ppl {:.3f}'.format(
                    self.config.epoch, self.config.epoch_step, dev_ppl))
            elif dev_ppl > self.config.degrade_threshold * self.config.best_dev_ppl:
                patience -= 1
                # Report the remaining patience budget, not the configured maximum.
                log.print_out(
                    ' worsened, ep {}|{} patience {} best_dev_ppl {:.3f}'.format(
                        self.config.epoch, self.config.epoch_step,
                        patience, self.config.best_dev_ppl))

            # Save config parameters
            self.config.save()

        # if global_step - last_external_eval_step >= steps_per_external_eval:
        #     last_external_eval_step = global_step
        #
        #     # Save checkpoint
        #     loaded_train_model.saver.save(
        #         train_sess,
        #         self.config.checkpoint_file,
        #         global_step=global_step)
        #     self.run_sample_decode(infer_model, infer_sess,
        #                            model_dir, summary_writer, eval_data)
        #     dev_scores, test_scores, _ = self.run_external_eval(infer_model, infer_sess, model_dir, summary_writer)

    # Done training
    loaded_train_model.saver.save(
        train_sess,
        self.config.checkpoint_file,
        global_step=global_step)

    # result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = self.run_full_eval(
    #     model_dir, infer_model, infer_sess,
    #     eval_model, eval_sess,
    #     summary_writer, eval_data)
    dev_scores, test_scores, dev_ppl, test_ppl = None, None, None, None
    result_summary = ""

    log.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()),
        log_f)
    log.print_time("# Done training!", start_train_time)

    summary_writer.close()

    # log.print_out("# Start evaluating saved best models.")
    # for metric in self.config.metrics:
    #     best_model_dir = getattr(self.config, "best_" + metric + "_dir")
    #     summary_writer = tf.summary.FileWriter(
    #         os.path.join(best_model_dir, summary_name), infer_model.graph)
    #     result_summary, best_global_step, _, _, _, _ = self.run_full_eval(
    #         best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
    #         summary_writer, eval_data)
    #     log.print_out("# Best %s, step %d "
    #                   "step-time %.2f wps %.2fK, %s, %s" %
    #                   (metric, best_global_step, avg_step_time, speed,
    #                    result_summary, time.ctime()), log_f)
    #     summary_writer.close()

    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
def train(self, target_session="", scope=None):
    assert self.config.num_turns >= 2

    if self.config.is_pretrain_enabled():
        assert self.config.num_pretrain_turns >= 2
        assert self.config.num_turns >= self.config.num_pretrain_turns

    out_dir = self.config.model_dir

    steps_per_stats = self.config.steps_per_stats
    steps_per_eval = 20 * steps_per_stats

    _helper = self._get_model_helper()

    self._pre_model_creation()

    train_model = _helper.create_train_model(self.config, scope)
    eval_model = _helper.create_eval_model(self.config, scope)
    infer_model = _helper.create_infer_model(self.config, scope)

    self._post_model_creation(train_model, eval_model, infer_model)

    # Preload data for sample decoding.
    dev_file = self.config.dev_data
    eval_data = self._load_data(dev_file, include_target=True)

    summary_name = "train_log"

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    log.print_out("# log_file=%s" % log_file, log_f)

    self.config.save()
    log.print_out("# Configs saved")

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = model_helper.get_config_proto(self.config.log_device)

    train_sess = tf.Session(
        target=target_session, config=config_proto, graph=train_model.graph)
    eval_sess = tf.Session(
        target=target_session, config=config_proto, graph=eval_model.graph)
    infer_sess = tf.Session(
        target=target_session, config=config_proto, graph=infer_model.graph)

    # Pretraining
    num_pretrain_steps = 0
    if self.config.is_pretrain_enabled():
        num_pretrain_steps = self.config.num_pretrain_steps
        pretrain_model = _helper.create_pretrain_model(self.config, scope)
        with tf.Session(
                target=target_session, config=config_proto,
                graph=pretrain_model.graph) as pretrain_sess:
            self.pretrain(pretrain_sess, pretrain_model, log_f)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, self.config.model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)

    last_stats_step = global_step
    last_eval_step = global_step
    patience = self.config.patience

    stats = self.init_stats()
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    log.print_out(
        "# Start step %d, epoch %d, lr %g, %s" %
        (global_step, self.config.epoch,
         loaded_train_model.learning_rate.eval(session=train_sess),
         time.ctime()),
        log_f)

    # Initialize all of the iterators
    skip_count = self.config.batch_size * self.config.epoch_step
    log.print_out("# Init train iterator for %d steps, skipping %d elements" %
                  (self.config.num_train_steps, skip_count))

    train_sess.run(
        train_model.iterator.initializer,
        feed_dict={train_model.skip_count_placeholder: skip_count})

    while self.config.epoch < self.config.num_train_epochs and patience > 0:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            self.config.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset. Go to next epoch.
            sw = Stopwatch()
            self.run_sample_decode(infer_model, infer_sess,
                                   self.config.model_dir, summary_writer, eval_data)

            # if self.config.enable_epoch_evals:
            #     dev_ppl, test_ppl = self.run_full_eval(infer_model, eval_model,
            #                                            infer_sess, eval_sess,
            #                                            out_dir,
            #                                            fs.file_name(self.config.test_data) + '_' + global_step,
            #                                            summary_writer)
            #     log.print_out(
            #         "%% done epoch %d #%d step %d - dev_ppl: %.2f test_ppl: %.2f @ eval time: %ds" %
            #         (self.config.epoch, self.config.epoch_step, global_step, dev_ppl, test_ppl, sw.elapsed()))
            # else:
            log.print_out(
                "## Done epoch %d in %d steps. step %d @ eval time: %ds" %
                (self.config.epoch, self.config.epoch_step, global_step, sw.elapsed()))

            self.config.epoch += 1
            self.config.epoch_step = 0
            self.config.save()

            train_sess.run(
                train_model.iterator.initializer,
                feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary and accumulate statistics
        global_step = self.update_stats(stats, summary_writer, start_time, step_result)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            train_ppl, speed, is_overflow = self.check_stats(stats, global_step, steps_per_stats, log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = self.init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            log.print_out("# Save eval, global step %d" % global_step)
            log.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          self.config.checkpoint_file,
                                          global_step=global_step)

            # Evaluate on dev
            self.run_sample_decode(infer_model, infer_sess, out_dir, summary_writer, eval_data)
            dev_ppl, _ = self.run_internal_eval(eval_model, eval_sess, out_dir, summary_writer,
                                                use_test_set=False)

            if dev_ppl < self.config.best_dev_ppl:
                self.config.best_dev_ppl = dev_ppl
                patience = self.config.patience
                log.print_out(' ** Best model thus far, ep {}|{} dev_ppl {:.3f}'.format(
                    self.config.epoch, self.config.epoch_step, dev_ppl))
            elif dev_ppl > self.config.degrade_threshold * self.config.best_dev_ppl:
                patience -= 1
                log.print_out(
                    ' worsened, ep {}|{} patience {} best_dev_ppl {:.3f}'.format(
                        self.config.epoch, self.config.epoch_step,
                        patience, self.config.best_dev_ppl))

            # Save config parameters
            self.config.save()

    # Done training
    loaded_train_model.saver.save(
        train_sess,
        self.config.checkpoint_file,
        global_step=global_step)

    if self.config.enable_final_eval:
        dev_ppl, test_ppl = self.run_full_eval(infer_model, eval_model,
                                               infer_sess, eval_sess,
                                               out_dir,
                                               fs.file_name(self.config.test_data) + '_final',
                                               summary_writer)
        log.print_out(
            "# Final, step %d ep %d/%d lr %g "
            "step-time %.2f wps %.2fK train_ppl %.2f, dev_ppl %.2f, test_ppl %.2f, %s" %
            (global_step, self.config.epoch, self.config.epoch_step,
             loaded_train_model.learning_rate.eval(session=train_sess),
             avg_step_time, speed, train_ppl, dev_ppl, test_ppl, time.ctime()),
            log_f)
    else:
        log.print_out(
            "# Final, step %d ep %d/%d lr %g "
            "step-time %.2f wps %.2fK train_ppl %.2f best_dev_ppl %.2f, %s" %
            (global_step, self.config.epoch, self.config.epoch_step,
             loaded_train_model.learning_rate.eval(session=train_sess),
             avg_step_time, speed, train_ppl, self.config.best_dev_ppl, time.ctime()),
            log_f)

    log.print_time("# Done training!", start_train_time)
    summary_writer.close()

    eval_sess.close()
    infer_sess.close()
    train_sess.close()
def infer(self, num_print_per_batch=0):
    model_dir = self.model_dir
    out_dir = self.out_dir
    dev_src_file = self.dev_src_file
    dev_tgt_file = self.dev_tgt_file
    infer_batch_size = self.hparams.infer_batch_size
    beam_width = self.hparams.beam_width
    infer_model = self.infer_model
    infer_sess = self.infer_sess
    infer_output_file = os.path.join(out_dir, 'infer_output')
    start_time = time.time()
    print('# Decoding to %s' % infer_output_file)

    # Load infer model
    with infer_model.graph.as_default():
        loaded_infer_model, global_step = model_helper.create_or_load_model(
            infer_model.model, model_dir, infer_sess, "infer")

    with open(dev_src_file, encoding='utf-8') as in_src_file, \
            open(dev_tgt_file, encoding='utf-8') as in_tgt_file, \
            open(infer_output_file, mode='w', encoding='utf-8') as out_file:
        infer_src_data = in_src_file.readlines()
        infer_tgt_data = in_tgt_file.readlines()
        iterator_feed_dict = {
            infer_model.src_data_placeholder: infer_src_data,
            infer_model.batch_size_placeholder: infer_batch_size
        }
        infer_sess.run(
            infer_model.iterator.initializer, feed_dict=iterator_feed_dict)

        num_sentences = 0
        while True:
            try:
                # The shape of sample_words is [batch_size, time] or
                # [beam_width, batch_size, time] when using beam search.
                sample_words = loaded_infer_model.decode(infer_sess)
                if beam_width == 0:
                    sample_words = np.expand_dims(sample_words, 0)

                batch_size = sample_words.shape[1]
                for sent_id in range(batch_size):
                    beam_id = random.randint(0, beam_width - 1) if beam_width > 0 else 0
                    response = self._get_response(sample_words[beam_id][sent_id])
                    out_file.write(response + '\n')
                    if sent_id < num_print_per_batch:
                        # Index of this sentence within the whole dataset.
                        global_sent_id = num_sentences + sent_id
                        print(" sentence %d" % global_sent_id)
                        print(" src: %s" % infer_src_data[global_sent_id], end='')
                        print(" ref: %s" % infer_tgt_data[global_sent_id], end='')
                        print(" bot: %s" % response)
                num_sentences += batch_size
            except tf.errors.OutOfRangeError:
                utils.print_time(
                    " done, num sentences %d, beam width %d" %
                    (num_sentences, beam_width), start_time)
                break
def train(hparams, scope=None, target_session=''):
    """Train the chatbot"""
    # Initialize some local hyperparameters
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if hparams.architecture == "simple":
        model_creator = SimpleModel
        get_infer_iterator = iterator_utils.get_infer_iterator
        get_iterator = iterator_utils.get_iterator
    elif hparams.architecture == "hier":
        model_creator = HierarchicalModel

        # Parse some of the arguments now
        def curry_get_infer_iterator(dataset, vocab_table, batch_size, src_reverse,
                                     eos, src_max_len):
            return end2end_iterator_utils.get_infer_iterator(
                dataset, vocab_table, batch_size, src_reverse, eos,
                src_max_len=src_max_len, eou=hparams.eou,
                dialogue_max_len=hparams.dialogue_max_len)

        get_infer_iterator = curry_get_infer_iterator

        def curry_get_iterator(src_dataset, tgt_dataset, vocab_table, batch_size,
                               sos, eos, src_reverse, random_seed, num_buckets,
                               src_max_len=None, tgt_max_len=None,
                               num_threads=4, output_buffer_size=None, skip_count=None):
            return end2end_iterator_utils.get_iterator(
                src_dataset, tgt_dataset, vocab_table, batch_size, sos, eos,
                eou=hparams.eou, src_reverse=src_reverse, random_seed=random_seed,
                num_dialogue_buckets=num_buckets, src_max_len=src_max_len,
                tgt_max_len=tgt_max_len, num_threads=num_threads,
                output_buffer_size=output_buffer_size, skip_count=skip_count)

        get_iterator = curry_get_iterator
    else:
        raise ValueError("Unknown architecture", hparams.architecture)

    # Create three models which share parameters through the use of checkpoints
    train_model = create_train_model(model_creator, get_iterator, hparams, scope)
    eval_model = create_eval_model(model_creator, get_iterator, hparams, scope)
    infer_model = inference.create_infer_model(model_creator, get_infer_iterator,
                                               hparams, scope)

    # ToDo: adapt for architectures
    # Preload the data to use for sample decoding
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # Create the configurations for the sessions
    config_proto = utils.get_config_proto(log_device_placement=log_device_placement)
    # Create three sessions, one for each model
    train_sess = tf.Session(target=target_session, config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session, config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session, config=config_proto,
                            graph=infer_model.graph)

    # Load the train model from checkpoint or create a new one
    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, name="train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)

    # First evaluation
    run_full_eval(
        model_dir, infer_model, infer_sess,
        eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    # Initialize the variables that track progress within it.
    step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
    checkpoint_total_count = 0.0
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
         time.ctime()),
        log_f)

    # epoch_step records where we were within an epoch. Used to skip
    # already-trained-on examples.
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    # Initialize the training iterator
    train_sess.run(
        train_model.iterator.initializer,
        feed_dict={train_model.skip_count_placeholder: skip_count})

    # Train until we reach num_steps.
    while global_step < num_train_steps:
        # Run a step
        start_step_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            (_, step_loss, step_predict_count, step_summary, global_step,  # The _ is the output of the update op
             step_word_count, batch_size) = step_result
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset. Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            # Decode and print a random sentence
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            # Perform external evaluation to save checkpoints if this is the
            # best for some metric
            dev_scores, test_scores, _ = run_external_evaluation(
                infer_model, infer_sess, model_dir, hparams, summary_writer,
                save_on_best_dev=True)
            # Reinitialize the iterator from the beginning
            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary.
        summary_writer.add_summary(step_summary, global_step)

        # Update statistics
        step_time += (time.time() - start_step_time)
        checkpoint_loss += (step_loss * batch_size)
        checkpoint_predict_count += step_predict_count
        checkpoint_total_count += float(step_word_count)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step

            # Print statistics for the previous logging window.
            avg_step_time = step_time / steps_per_stats
            train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
            speed = checkpoint_total_count / (1000 * step_time)
            utils.print_out(
                " global step %d lr %g "
                "step-time %.2fs wps %.2fK ppl %.2f %s" %
                (global_step,
                 loaded_train_model.learning_rate.eval(session=train_sess),
                 avg_step_time, speed, train_ppl, _get_best_results(hparams)),
                log_f)
            if math.isnan(train_ppl):
                # The model has diverged.
                break

            # Reset timer and loss.
            step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
            checkpoint_total_count = 0.0

        if global_step - last_eval_step >= steps_per_eval:
            # Perform evaluation. Start by reassigning last_eval_step to the
            # current step.
            last_eval_step = global_step

            # Print the progress and add summary
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir, "chatbot.ckpt"),
                                          global_step=global_step)
            # Decode and print a random sample
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            # Run internal evaluation, and update the ppl variables. The data
            # iterator is instantiated inside the method.
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess, model_dir,
                                                  hparams, summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            # Run the external evaluation
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir, "chatbot.ckpt"),
                                          global_step=global_step)
            # Decode and print a random sample
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            # Run external evaluation, updating metric scores along the way.
            # The unneeded output is the global step.
            dev_scores, test_scores, _ = run_external_evaluation(
                infer_model, infer_sess, model_dir, hparams, summary_writer,
                save_on_best_dev=True)

    # Done training. Save the model
    loaded_train_model.saver.save(
        train_sess,
        os.path.join(out_dir, "chatbot.ckpt"),
        global_step=global_step)

    result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
        model_dir, infer_model, infer_sess,
        eval_model, eval_sess,
        hparams, summary_writer,
        sample_src_data, sample_tgt_data)

    utils.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()),
        log_f)
    utils.print_time("# Done training!", start_train_time)

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        result_summary, best_global_step, _, _, _, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out("# Best %s, step %d "
                        "step-time %.2f wps %.2fK, %s, %s" %
                        (metric, best_global_step, avg_step_time, speed,
                         result_summary, time.ctime()), log_f)

    summary_writer.close()
    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
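# utils.safe_exp, used for train_ppl above, guards against overflow when the
# accumulated loss is large. A minimal sketch of this common NMT-style
# utility; the repository's utils module may differ:
import math

def safe_exp(value):
    """exp(value), returning inf instead of raising on overflow."""
    try:
        ans = math.exp(value)
    except OverflowError:
        ans = float("inf")
    return ans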