def do_SL_training(self, model, sess, global_step, train_handle):
        # first do initialization
        if not self.train_it_initialized:
            sess.run(model.train_iterator.initializer,
                     feed_dict={model.skip_count_placeholder: 0})
            self.train_it_initialized = True

        # second, do training
        while True:  # keep trying until no exception
            try:
                res = model.model.self_play_train(sess, train_handle)
                break
            except tf.errors.OutOfRangeError:
                sess.run(model.train_iterator.initializer,
                         feed_dict={model.skip_count_placeholder: 0})
                continue

        all_summaries = res[-1]
        if self.summary_writer:
            for key in all_summaries:
                utils.add_summary(self.summary_writer, global_step,
                                  self.dialogue_mode + '_' + key,
                                  all_summaries[key])
        global_step = res[-2]
        self.num_sl_updates += 1
        return global_step

    def do_rl_training(self, data, kb, batch_size, model, sess, speaker,
                       global_step, self_play_handle):
        if self.iterator_mode == 1:
            self_play_iterator = model.self_play_ft_iterator
        elif self.iterator_mode == 2:
            self_play_iterator = model.self_play_st_iterator
        else:
            raise Exception('self_play iterator_mode is not defined')

        # first do initialization
        sess.run(self_play_iterator.initializer,
                 feed_dict={
                     model.data_placeholder: data,
                     model.kb_placeholder: kb,
                     model.batch_size_placeholder: batch_size
                 })
        # second, do training
        res = model.model.self_play(sess, speaker, self_play_handle)
        all_summaries = res[-1]
        if self.summary_writer:
            for key in all_summaries:
                utils.add_summary(self.summary_writer, global_step,
                                  self.dialogue_mode + '_' + key,
                                  all_summaries[key])
        global_step = res[2]
        self.num_rl_updates += 1
        return global_step
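
A minimal, self-contained sketch of the retry-on-exhaustion pattern used in do_SL_training above, assuming a TF 1.x tf.data pipeline: when the initializable iterator runs dry, session.run raises tf.errors.OutOfRangeError, and the loop simply re-runs the initializer to start a new epoch. All names below are illustrative.

import tensorflow as tf

dataset = tf.data.Dataset.range(5).batch(2)
iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    for _ in range(5):  # request more batches than one epoch holds
        try:
            print(sess.run(next_batch))
        except tf.errors.OutOfRangeError:
            # dataset exhausted: reinitialize and keep going, as above
            sess.run(iterator.initializer)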
Example #3
def _internal_eval(model, global_step, sess, iterator, iterator_feed_dict,
                   summary_writer, label):
    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)
    perplexity = model_util.compute_perplexity(model, sess, label)
    utils.add_summary(summary_writer, global_step, "%s_perplexity" % label,
                      perplexity)
    return perplexity
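
Every example on this page funnels scalar metrics through utils.add_summary. The helper itself is not shown here; a sketch consistent with how it is called throughout (a tag and a scalar value at a given step), using the standard TF 1.x tf.Summary proto:

def add_summary(summary_writer, global_step, tag, value):
    """Write `value` as a scalar summary under `tag` at `global_step`."""
    summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
    summary_writer.add_summary(summary, global_step)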
Example #4
def _internal_eval(model, global_step, sess, iterator, iterator_feed_dict,
                   summary_writer, label):
    """Computing perplexity."""
    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)
    ppl = model_helper.compute_perplexity(model, sess, label)
    utils.add_summary(summary_writer, global_step, "%s_ppl" % label, ppl)
    return ppl
Example #5
def _external_eval(model,
                   global_step,
                   sess,
                   hparams,
                   iterator,
                   iterator_feed_dict,
                   tgt_file,
                   label,
                   summary_writer,
                   save_on_best,
                   avg_ckpts=False):
    """External evaluation such as BLEU and ROUGE scores."""
    out_dir = hparams.out_dir
    decode = global_step > 0

    if avg_ckpts:
        label = "avg_" + label

    if decode:
        utils.print_out("# External evaluation, global step %d" % global_step)

    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)

    output = os.path.join(out_dir, "output_%s" % label)
    scores = nmt_utils.decode_and_evaluate(
        label,
        model,
        sess,
        output,
        ref_file=tgt_file,
        metrics=hparams.metrics,
        subword_option=hparams.subword_option,
        beam_width=hparams.beam_width,
        tgt_eos=hparams.eos,
        decode=decode,
        infer_mode=hparams.infer_mode)
    # Save on best metrics
    if decode:
        for metric in hparams.metrics:
            if avg_ckpts:
                best_metric_label = "avg_best_" + metric
            else:
                best_metric_label = "best_" + metric

            utils.add_summary(summary_writer, global_step,
                              "%s_%s" % (label, metric), scores[metric])
            # metric: larger is better
            if save_on_best and scores[metric] > getattr(
                    hparams, best_metric_label):
                setattr(hparams, best_metric_label, scores[metric])
                model.saver.save(sess,
                                 os.path.join(
                                     getattr(hparams,
                                             best_metric_label + "_dir"),
                                     "translate.ckpt"),
                                 global_step=model.global_step)
        utils.save_hparams(out_dir, hparams)
    return scores
Example #6
    def talk(self,
             max_diag_length,
             batch_input_data,
             batch_input_kb,
             agent1,
             agent2,
             worker_step,
             batch_size,
             speaker=None):
        """The main procedure to generate a single self play conversation."""
        # parse data
        bs_intent, bs_truth_action, bs_kb = self.parse_input(
            batch_input_data, batch_input_kb)
        # remember the roles of agents
        self.agents = [agent1, agent2]
        # In self-play training the speaker will be None, so we randomly choose
        # an initial speaker and initialize the utterance.
        # In self-play evaluation the speaker is specified, so we use it as is.
        if speaker is None:
            speaker = int(np.random.random() < 0.5)
        # generate the conversation instance for this conversation.
        # print ('self.batch_size', self.batch_size)
        conv = Conversation(max_diag_length, self.turn1_token,
                            self.turn2_token, batch_size, speaker)

        # generate the conversation turn by turn in batch mode until all
        # conversations have terminated (finished = True) or the number of
        # turns reaches the maximum.
        turn = 0
        finished = False
        while (not finished) and turn < self.max_dialogue_turns:
            finished = self.generate_utterance(bs_intent, conv, bs_kb, speaker,
                                               turn, batch_size)
            #  Change the speaker as we move to the next turn.
            speaker = (speaker + 1) % 2
            turn += 1

        all_rewards = dialogue_utils.compute_reward_batch(
            conv.utt_arr, conv.action_arr, bs_truth_action, bs_kb,
            self.hparams)
        metrics = dialogue_utils.calculate_reward_metrics(all_rewards)
        metrics['num_turns'] = turn

        #  print out step stats only in debug mode
        if self.summary_writer and self.hparams.debug:
            for key in metrics:
                utils.add_summary(self.summary_writer, worker_step,
                                  self.dialogue_mode + '_' + key + '_ws',
                                  metrics[key])

        utt_arr, bs_pred_action = conv.get_train_data()

        if self.hparams.debug:
            print('self_play debug: ' + bs_intent[0])
            print('self_play debug: all_rewards', all_rewards[0])
            print('self_play debug: ' + ' '.join(utt_arr[0]))
            print('self_play debug: ' + ' '.join(bs_pred_action[0]))
            sys.stdout.flush()
        return (bs_intent, bs_pred_action, bs_truth_action, utt_arr,
                bs_kb), turn, metrics
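
A tiny illustration of the turn-taking logic in talk(): when no speaker is given, the initial speaker is a fair coin flip, and the role then alternates every turn by modular arithmetic.

import numpy as np

speaker = int(np.random.random() < 0.5)  # 0 or 1 with equal probability
for turn in range(4):
    print('turn', turn, 'speaker', speaker)
    speaker = (speaker + 1) % 2  # hand the turn to the other agent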
Example #7
def _internal_eval(model, global_step, sess, iterator, iterator_feed_dict, summary_writer, label):
    """Used to compute perplexity on the dataset provided through the iterator"""
    # Initialize the iterator using the feed dict
    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)
    # Compute the perplexity
    ppl = model_helper.compute_perplexity(model, sess, name=label)
    # Add summary for the ppl to the summary writer
    utils.add_summary(summary_writer, global_step, tag="%s_ppl" % label, value=ppl)

    return ppl
Example #8
def _internal_eval(model, global_step, sess, real_iterator, iterator_feed_dict,
                   iterator_handle, summary_writer, label):
  """Computing perplexity."""
  sess.run(real_iterator.initializer, feed_dict=iterator_feed_dict)
  ppl, all_summaries = model_helper.compute_perplexity(model, sess, label,
                                                       iterator_handle)

  utils.add_summary(summary_writer, global_step, "%s_ppl" % label, ppl)
  for key in all_summaries:
    utils.add_summary(summary_writer, global_step, key, all_summaries[key])
  return ppl
Example #9
def handle_summary(diag_mode, summary_writer, global_step, all_summary,
                   summary_weight):
    """Handle all summaries and combine them together."""
    combined = {}
    for summary in all_summary:
        for key in summary:
            if key not in combined:
                combined[key] = []
            combined[key].append(summary[key])
    print('combined', combined)
    for key in combined:
        combined[key] = np.average(combined[key], weights=summary_weight)
        name = diag_mode + '_' + key
        utils.add_summary(summary_writer, global_step, name, combined[key])
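
An illustrative call to handle_summary with made-up values: each dict in all_summary comes from one batch, and summary_weight holds the batch sizes so that np.average yields per-example means. summary_writer and global_step are assumed to come from the surrounding loop.

all_summary = [{'reward': 0.5, 'num_turns': 6.0},
               {'reward': 0.7, 'num_turns': 4.0}]
summary_weight = [32, 16]  # dialogues contributing to each summary dict
handle_summary('self_play_eval', summary_writer, global_step,
               all_summary, summary_weight)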
Example #10
def external_eval(model,
                  global_step,
                  sess,
                  hparams,
                  iterator,
                  iterator_feed_dict,
                  tgt_file,
                  label,
                  summary_writer,
                  save_on_best,
                  avg_ckpts=False):
    """
    External evaluation such as BLEU and ROUGE scores.
    """

    out_dir = hparams.out_dir
    if avg_ckpts:
        label = "avg_" + label

    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)

    output = os.path.join(out_dir, "output_%s" % label)
    scores = model.decode_and_evaluate(label,
                                       sess,
                                       output,
                                       ref_file=tgt_file,
                                       beam_width=hparams.beam_width,
                                       tgt_eos=hparams.eos)

    # Save on best metrics
    if avg_ckpts:
        best_metric_label = "avg_best_bleu"
    else:
        best_metric_label = "best_bleu"
    utils.add_summary(summary_writer, global_step, "%s_bleu" % (label, ),
                      scores['BLEU'])
    # metric: larger is better
    if save_on_best and scores['BLEU'] > getattr(hparams, best_metric_label):
        setattr(hparams, best_metric_label, scores['BLEU'])
        model.saver.save(sess,
                         os.path.join(
                             getattr(hparams, best_metric_label + "_dir"),
                             "translate.ckpt"),
                         global_step=model.global_step)

    utils.save_hparams(out_dir, hparams)

    return scores
Example #11
def _external_eval(model, global_step, sess, hparams, iterator,
                   iterator_feed_dict, tgt_file, label, summary_writer,
                   save_on_best_dev):
    """External evaluation such as BLEU and ROUGE scores. If save on best then keep the best scores in the hparams"""
    out_dir = hparams.out_dir
    # Avoids running eval when global step is 0
    decode = global_step > 0
    if decode:
        utils.print_out("# External evaluation, global step %d" % global_step)
    # Initialize the iterator
    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)
    # Create the output file for the logs
    output_file = os.path.join(out_dir, "output_%s" % label)
    # Get the scores for the metrics
    scores = chatbot_utils.decode_and_evaluate(
        name=label,
        model=model,
        sess=sess,
        output_file=output_file,
        reference_file=tgt_file,
        metrics=hparams.metrics,
        bpe_delimiter=hparams.bpe_delimiter,
        beam_width=hparams.beam_width,
        eos=hparams.eos,
        number_token=hparams.number_token,
        name_token=hparams.name_token,
        decode=decode
    )
    # Create the summaries and also save the best
    if decode:
        for metric in hparams.metrics:
            # Create the summary
            utils.add_summary(summary_writer, global_step, "%s_%s" % (label, metric),
                              scores[metric])
            # Is the current metric score better than the last
            if save_on_best_dev and scores[metric] > getattr(hparams, "best_" + metric):
                # Update the hparams score
                setattr(hparams, "best_" + metric, scores[metric])
                # Save the model which got the best for this metric to file
                model.saver.save(sess,
                                 os.path.join(getattr(hparams, "best_" + metric + "_dir"), "dialogue.ckpt"),
                                 global_step=model.global_step)  # For safety
    # Save the hparams to file
    utils.save_hparams(out_dir, hparams, verbose=True)

    return scores
Example #12
def single_worker_inference(infer_model, infer_sess, eval_model, eval_sess,
                            ckpt, summary_writer, global_step, hparams):
  """the actual function for inference."""
  # load datasets
  infer_src_data = load_data(hparams.infer_src_data)
  infer_tar_data = load_data(hparams.infer_tar_data)
  infer_kb = load_data(hparams.infer_kb)

  # load model and session
  start_time = time.time()
  with infer_model.graph.as_default():
    loaded_infer_model = model_helper.load_model(infer_model.model, ckpt,
                                                 infer_sess, "infer")
    infer_sess.run(
        infer_model.infer_iterator.initializer,
        feed_dict={
            infer_model.data_src_placeholder: infer_src_data,
            infer_model.kb_placeholder: infer_kb,
            infer_model.batch_size_placeholder: hparams.infer_batch_size
        })
    infer_handle = infer_sess.run(infer_model.infer_iterator.string_handle())

    # Decode
    utils.print_out("# Start decoding")
    evaluation_scores = dialogue_utils.decode_and_evaluate(
        "infer",
        loaded_infer_model,
        infer_handle,
        infer_sess,
        hparams.inference_output_file,
        ref_file=hparams.infer_tar_data,
        metrics=hparams.metrics,
        hparams=hparams,
        infer_src_data=infer_src_data)
    # summary writer
    for key in evaluation_scores:
      # utils.add_summary(summary_writer,)
      utils.add_summary(summary_writer, global_step, key,
                        evaluation_scores[key])
    # sample some dialogue and decode them for qualitative examination
    _sample_decode(loaded_infer_model, global_step, infer_handle, infer_sess,
                   hparams, infer_model.infer_iterator, infer_src_data,
                   infer_tar_data, infer_kb, infer_model.data_src_placeholder,
                   infer_model.kb_placeholder,
                   infer_model.batch_size_placeholder)
  # run eval model to get perplexity
  eval_handle = eval_sess.run(eval_model.eval_iterator.string_handle())
  dev_ppl, _ = run_internal_eval(eval_model, eval_handle, eval_sess,
                                 hparams.out_dir, hparams, summary_writer)
  utils.add_summary(summary_writer, global_step, "dev_ppl", dev_ppl)
  total_inference_time = time.time() - start_time
  utils.add_summary(summary_writer, global_step, "infer_time",
                    total_inference_time)
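
Several examples above pass iterator string handles (train_handle, infer_handle, eval_handle) into session.run. A minimal sketch of the TF 1.x feedable-iterator pattern behind them, with illustrative names: one generic iterator reads from whichever concrete iterator's handle is fed in.

import tensorflow as tf

handle_ph = tf.placeholder(tf.string, shape=[])
dataset = tf.data.Dataset.range(4).batch(2)
concrete_it = dataset.make_initializable_iterator()
generic_it = tf.data.Iterator.from_string_handle(
    handle_ph, dataset.output_types, dataset.output_shapes)
next_elem = generic_it.get_next()

with tf.Session() as sess:
    handle = sess.run(concrete_it.string_handle())
    sess.run(concrete_it.initializer)
    print(sess.run(next_elem, feed_dict={handle_ph: handle}))  # [0 1]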
Example #13
def train(hparams, scope=None):

    model_dir = hparams.out_dir
    avg_ckpts = hparams.avg_ckpts
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval
    summary_name = "summary"

    model_creator = gnmt_model.GNMTModel
    train_model = model_helper.create_train_model(model_creator, hparams)
    eval_model = model_helper.create_eval_model(model_creator, hparams)
    infer_model = model_helper.create_infer_model(model_creator, hparams)

    config_proto = tf.ConfigProto()
    config_proto.gpu_options.allow_growth = True
    train_sess = tf.Session(graph=train_model.graph, config=config_proto)
    eval_sess = tf.Session(graph=eval_model.graph, config=config_proto)
    infer_sess = tf.Session(graph=infer_model.graph, config=config_proto)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(model_dir, summary_name), train_model.graph)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = utils.load_data(dev_src_file)
    sample_tgt_data = utils.load_data(dev_tgt_file)

    # First evaluation
    result_summary, _, _ = run_full_eval(model_dir, infer_model, infer_sess,
                                         eval_model, eval_sess, hparams,
                                         summary_writer, sample_src_data,
                                         sample_tgt_data, avg_ckpts)
    utils.log('First evaluation: {}'.format(result_summary))

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats = init_stats()
    info = {
        "train_ppl": 0.0,
        "speed": 0.0,
        "avg_step_time": 0.0,
        "avg_grad_norm": 0.0,
        "learning_rate":
        loaded_train_model.learning_rate.eval(session=train_sess)
    }
    utils.log("Start step %d, lr %g" % (global_step, info["learning_rate"]))

    # Initialize all of the iterators
    train_sess.run(train_model.iterator.initializer)

    epoch = 1

    while True:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            utils.log(
                "Finished epoch %d, step %d. Perform external evaluation" %
                (epoch, global_step))
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)
            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)
            train_sess.run(train_model.iterator.initializer)

            if epoch < hparams.epochs:
                epoch += 1
                continue
            else:
                break

        # Process step_result, accumulate stats, and write summary
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step

            is_overflow = process_stats(stats, info, global_step,
                                        steps_per_stats)
            print_step_info("  ", global_step, info,
                            "BLEU %.2f" % (hparams.best_bleu, ))
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step

            utils.log("Save eval, global step %d" % (global_step, ))
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              info["train_ppl"])

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(model_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(model_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)
            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(model_dir, "translate.ckpt"),
                                  global_step=global_step)

    (result_summary, _, final_eval_metrics) = run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data, avg_ckpts)
    print_step_info("Final, ", global_step, info, result_summary)
    utils.log("Done training!")

    summary_writer.close()

    utils.log("Start evaluating saved best models.")
    best_model_dir = hparams.best_bleu_dir
    summary_writer = tf.summary.FileWriter(
        os.path.join(best_model_dir, summary_name), infer_model.graph)
    result_summary, best_global_step, _ = run_full_eval(
        best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
        hparams, summary_writer, sample_src_data, sample_tgt_data)
    print_step_info("Best BLEU, ", best_global_step, info, result_summary)
    summary_writer.close()

    if avg_ckpts:
        best_model_dir = hparams.avg_best_bleu_dir
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        print_step_info("Averaged Best BLEU, ", best_global_step, info,
                        result_summary)
        summary_writer.close()

    return final_eval_metrics, global_step
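
The training loops on this page all schedule periodic work the same way: remember the step at which a task last ran, and trigger again once global_step has advanced by the configured interval. A stripped-down illustration:

last_stats_step = 0
steps_per_stats = 100
for global_step in range(1, 501):
    if global_step - last_stats_step >= steps_per_stats:
        last_stats_step = global_step
        print('printing stats at step', global_step)  # fires every 100 steps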
Example #14
def train(hparams, identity, scope=None, target_session=""):
  """main loop to train the dialogue model. identity is used."""
  out_dir = hparams.out_dir
  steps_per_stats = hparams.steps_per_stats
  steps_per_internal_eval = 3 * steps_per_stats

  model_creator = diag_model.Model

  train_model = model_helper.create_train_model(model_creator, hparams, scope)

  model_dir = hparams.out_dir

  # Log and output files
  log_file = os.path.join(out_dir, identity+"log_%d" % time.time())
  log_f = tf.gfile.GFile(log_file, mode="a")
  utils.print_out("# log_file=%s" % log_file, log_f)

  avg_step_time = 0.0

  # load TensorFlow session and model
  config_proto = utils.get_config_proto(
      log_device_placement=hparams.log_device_placement,
      allow_soft_placement=True)

  train_sess = tf.Session(
      target=target_session, config=config_proto, graph=train_model.graph)

  train_handle = train_sess.run(train_model.train_iterator.string_handle())

  with train_model.graph.as_default():
    loaded_train_model, global_step = model_helper.create_or_load_model(
        train_model.model, model_dir, train_sess, "train")

  # initialize summary writer
  summary_writer = tf.summary.FileWriter(
      os.path.join(out_dir, "train_log"), train_model.graph)

  last_stats_step = global_step
  last_eval_step = global_step

  # initialize training stats.
  step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
  checkpoint_total_count = 0.0
  speed, train_ppl = 0.0, 0.0
  start_train_time = time.time()

  utils.print_out(
      "# Start step %d, lr %g, %s" %
      (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
       time.ctime()),
      log_f)

  # initialize iterators
  skip_count = hparams.batch_size * hparams.epoch_step
  utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
  train_sess.run(
      train_model.train_iterator.initializer,
      feed_dict={train_model.skip_count_placeholder: skip_count})

  # main training loop
  while global_step < hparams.num_train_steps:
    start_time = time.time()
    try:  #  run a step
      step_result = loaded_train_model.train(train_sess, train_handle)
      (_, step_loss, all_summaries, step_predict_count, step_summary,
       global_step, step_word_count, batch_size, _, _, words1, words2, mask1,
       mask2) = step_result
      hparams.epoch_step += 1

    except tf.errors.OutOfRangeError:  # finished an epoch
      hparams.epoch_step = 0
      utils.print_out("# Finished an epoch, step %d." % global_step)
      train_sess.run(
          train_model.train_iterator.initializer,
          feed_dict={train_model.skip_count_placeholder: 0})
      continue

    # Write step summary.
    summary_writer.add_summary(step_summary, global_step)
    for key in all_summaries:
      utils.add_summary(summary_writer, global_step, key, all_summaries[key])

    # update statistics
    step_time += (time.time() - start_time)

    checkpoint_loss += (step_loss * batch_size)
    checkpoint_predict_count += step_predict_count
    checkpoint_total_count += float(step_word_count)

    if global_step - last_stats_step >= steps_per_stats:
      # print statistics for the previous epoch and save the model.
      last_stats_step = global_step

      avg_step_time = step_time / steps_per_stats
      utils.add_summary(summary_writer, global_step, "step_time", avg_step_time)
      train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
      speed = checkpoint_total_count / (1000 * step_time)
      if math.isnan(train_ppl):
        break

      # Reset timer and loss.
      step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
      checkpoint_total_count = 0.0

      # save the model
      loaded_train_model.saver.save(
          train_sess,
          os.path.join(out_dir, "dialogue.ckpt"),
          global_step=global_step)

      # print the dialogue if in debug mode
      if hparams.debug:
        utils.print_current_dialogue(words1, words2, mask1, mask2)

    #  write out internal evaluation
    if global_step - last_eval_step >= steps_per_internal_eval:
      last_eval_step = global_step

      utils.print_out("# Internal Evaluation. global step %d" % global_step)
      utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

  # finished training
  loaded_train_model.saver.save(
      train_sess,
      os.path.join(out_dir, "dialogue.ckpt"),
      global_step=global_step)
  result_summary = ""
  utils.print_out(
      "# Final, step %d lr %g "
      "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
      (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
       avg_step_time, speed, train_ppl, result_summary, time.ctime()),
      log_f)
  utils.print_time("# Done training!", start_train_time)
  utils.print_out("# Start evaluating saved best models.")
  summary_writer.close()
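
train_ppl above is computed as utils.safe_exp(checkpoint_loss / checkpoint_predict_count). That helper is not shown on this page; a sketch matching its usage, guarding math.exp against overflow when the loss diverges:

import math

def safe_exp(value):
    """exp(value), returning inf instead of raising OverflowError."""
    try:
        return math.exp(value)
    except OverflowError:
        return float("inf")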
Example #15
def add_info_summaries(summary_writer, global_step, info):
    """Add stuffs in info to summaries."""
    excluded_list = ["learning_rate"]
    for key in info:
        if key not in excluded_list:
            utils.add_summary(summary_writer, global_step, key, info[key])
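
An illustrative call with made-up numbers, mirroring the info dict built in Example #13: every entry except learning_rate is written out as a scalar summary.

info = {"train_ppl": 12.3, "speed": 1.8, "avg_step_time": 0.4,
        "avg_grad_norm": 2.1, "learning_rate": 1e-3}
add_info_summaries(summary_writer, global_step, info)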
Example #16
def train(args):
    start_time = time.time()
    vocab, model_config = load_vocab_setup_config(args)
    print('... load vocab and setup model config over, cost:\t%.2f s' %
          (time.time() - start_time))
    print('... vocab size:\t%d' % vocab.size)

    start_time = time.time()
    train_iter, valid_iter, test_iter = get_train_iter(args.data_dir,
                                                       vocab=vocab,
                                                       config=model_config)
    print('-' * 100)
    print('... load train and valid data iterator over, cost:\t%.2f s' %
          (time.time() - start_time))
    print('... train iterator samples:\t%d' % train_iter.num_samples)
    print('... valid iterator samples:\t%d' % valid_iter.num_samples)
    print('... test iterator samples:\t%d' % test_iter.num_samples)

    # prepare output dir
    output_dir = args.output_dir
    ckpt_dir = os.path.join(output_dir, "checkpoints")
    log_dir = os.path.join(output_dir, "train_log")
    pred_dir = os.path.join(output_dir, "pred")

    utils.mkdir_not_exists(output_dir)
    utils.mkdir_not_exists(ckpt_dir)
    utils.mkdir_not_exists(log_dir)
    utils.mkdir_not_exists(pred_dir)
    ckpt_path = os.path.join(ckpt_dir, model_config.model)

    print('=' * 100)
    print('... building model')
    start_time = time.time()
    if model_config.model == 'HRED':
        model = HREDModel
    else:
        raise NotImplementedError("No such model")

    config_proto = tf.ConfigProto()
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    with tf.Session(config=config_proto) as sess:
        initializer = tf.random_uniform_initializer(-1.0 * model_config.init_w,
                                                    model_config.init_w)
        scope = model_config.model
        with tf.variable_scope(scope, reuse=None, initializer=initializer):
            train_model = model(config=model_config,
                                mode=ModelMode.train,
                                scope=scope)

        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            eval_model = model(config=model_config,
                               mode=ModelMode.eval,
                               scope=scope)

        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            infer_model = model(config=model_config,
                                mode=ModelMode.infer,
                                scope=scope)

    print('... create %s model over, time cost: %.2fs' %
          (model_config.model, time.time() - start_time))
    print('=' * 100)
    sess.run(tf.global_variables_initializer())

    # Summary writer
    summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

    start_time = time.time()
    ckpt_loss, ckpt_ppl, ckpt_predict_count, ckpt_samples = 0.0, 0.0, 0, 0
    for epoch in range(model_config.max_epoch):
        train_samples = 0
        for batch_data in train_iter.next_batch():
            # print("batch data shape:", batch_data.dialog.shape, "batch length shape:", batch_data.dialog_length.shape)
            # print(batch_data.dialog_length)
            step_loss, step_ppl, step_predict_count, batch_size, step_summary, global_step = \
                train_model.train(sess, batch_data)

            ckpt_samples += batch_size
            ckpt_loss += step_loss * batch_size
            ckpt_ppl += step_ppl
            ckpt_predict_count += step_predict_count
            train_samples += batch_size

            summary_writer.add_summary(step_summary, global_step)

            if global_step % model_config.display_frequency == 0:
                train_loss = ckpt_loss / ckpt_samples
                train_ppl = safe_exp(ckpt_ppl / ckpt_predict_count)
                print(
                    'Epoch: %d/%d; Samples: %d/%d; Step: %d; Train Loss: %.2f; Train PPL: %.2f; Time Cost: %.2fs'
                    % (epoch + 1, model_config.max_epoch, train_samples,
                       train_iter.num_samples, global_step, train_loss,
                       train_ppl, time.time() - start_time))

                utils.add_summary(summary_writer, global_step, "train_ppl",
                                  train_ppl)

                ckpt_loss, ckpt_ppl, ckpt_predict_count, ckpt_samples = 0.0, 0.0, 0, 0
                start_time = time.time()

            if global_step % model_config.checkpoint_frequency == 0:
                print("--------- evaluate model ------------")
                start_time = time.time()
                valid_loss, valid_ppl, valid_bleu = run_evaluate(
                    sess, eval_model, infer_model, vocab, model_config,
                    global_step, args.data_dir, pred_dir, valid_iter, "valid")

                print(
                    'Epoch: %d/%d; Step: %d; Valid Loss: %.2f; Valid PPL: %.2f; Valid Bleu:%.2f; Time Cost: %.2fs'
                    % (epoch + 1, model_config.max_epoch,
                       global_step, valid_loss, valid_ppl, valid_bleu,
                       time.time() - start_time))

                start_time = time.time()

                test_loss, test_ppl, test_bleu = run_evaluate(
                    sess, eval_model, infer_model, vocab, model_config,
                    global_step, args.data_dir, pred_dir, test_iter, "test")

                print(
                    'Epoch: %d/%d; Step: %d; Test Loss: %.2f; Test PPL: %.2f; Test Bleu:%.2f; Time Cost: %.2fs'
                    %
                    (epoch + 1, model_config.max_epoch, global_step, test_loss,
                     test_ppl, test_bleu, time.time() - start_time))

                # save summary and checkpoints
                utils.add_summary(summary_writer, global_step, "valid_ppl",
                                  valid_ppl)
                utils.add_summary(summary_writer, global_step, "valid_bleu",
                                  valid_bleu)
                utils.add_summary(summary_writer, global_step, "test_ppl",
                                  test_ppl)
                utils.add_summary(summary_writer, global_step, "test_bleu",
                                  test_bleu)
                summary_writer.flush()

                train_model.save(sess, ckpt_path)

                start_time = time.time()

    # done training
    summary_writer.close()

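utils.mkdir_not_exists, used above to prepare the output directories, is not defined on this page; a plausible sketch consistent with its name and use:

import os

def mkdir_not_exists(path):
    """Create directory `path` only if it does not already exist."""
    if not os.path.exists(path):
        os.makedirs(path)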
Example #17
def train(hps, scope=None, target_session=""):
    """Train a translation model."""
    log_device_placement = hps.log_device_placement
    out_dir = hps.out_dir
    num_train_steps = hps.num_train_steps
    steps_per_stats = hps.steps_per_stats
    steps_per_external_eval = hps.steps_per_external_eval
    steps_per_eval = 100 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if hps.attention_architecture == "baseline":
        model_creator = AttentionModel
    else:
        model_creator = AttentionHistoryModel

    train_model = model_helper.create_train_model(model_creator, hps, scope)
    eval_model = model_helper.create_eval_model(model_creator, hps, scope)
    infer_model = model_helper.create_infer_model(model_creator, hps, scope)

    # Preload data for sample decoding.

    article_filenames = []
    abstract_filenames = []
    art_dir = hps.data_dir + '/article'
    abs_dir = hps.data_dir + '/abstract'
    for file in os.listdir(art_dir):
        if file.startswith(hps.dev_prefix):
            article_filenames.append(art_dir + "/" + file)
    for file in os.listdir(abs_dir):
        if file.startswith(hps.dev_prefix):
            abstract_filenames.append(abs_dir + "/" + file)
    # if random_decode:
    #     """if this is a random sampling process during training"""
    decode_id = random.randint(0, len(article_filenames) - 1)
    single_article_file = article_filenames[decode_id]
    single_abstract_file = abstract_filenames[decode_id]

    dev_src_file = single_article_file
    dev_tgt_file = single_abstract_file
    sample_src_data = inference_base_model.load_data(dev_src_file)
    sample_tgt_data = inference_base_model.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hps.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement)

    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    # run_full_eval(
    #     model_dir, infer_model, infer_sess,
    #     eval_model, eval_sess, hps,
    #     summary_writer,sample_src_data,sample_tgt_data)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats = init_stats()
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(
            session=train_sess), time.ctime()), log_f)

    # Initialize all of the iterators
    skip_count = hps.batch_size * hps.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    train_sess.run(train_model.iterator.initializer,
                   feed_dict={train_model.skip_count_placeholder: skip_count})
    epoch_step = 0
    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)

            run_sample_decode(infer_model, infer_sess, model_dir, hps,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hps, summary_writer)
            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary and accumulate statistics
        global_step = update_stats(stats, summary_writer, start_time,
                                   step_result)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = check_stats(stats, global_step, steps_per_stats, hps,
                                      log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step

            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "summarized.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hps,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess,
                                                  model_dir, hps,
                                                  summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "summarized.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hps,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hps, summary_writer)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "summarized.ckpt"),
                                  global_step=global_step)

    result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hps,
        summary_writer, sample_src_data, sample_tgt_data)
    utils.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step,
         loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()), log_f)
    utils.print_time("# Done training!", start_train_time)

    summary_writer.close()

    utils.print_out("# Start evaluating saved best models.")
    for metric in hps.metrics:
        best_model_dir = getattr(hps, "best_" + metric + "_dir")
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _, _, _, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hps, summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out(
            "# Best %s, step %d "
            "step-time %.2f wps %.2fK, %s, %s" %
            (metric, best_global_step, avg_step_time, speed, result_summary,
             time.ctime()), log_f)
        summary_writer.close()

    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
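
Several of these training loops initialize their iterator with a skip_count_placeholder so a restarted job can fast-forward past the batch_size * epoch_step examples already consumed. A minimal TF 1.x sketch of that wiring, with illustrative names:

import tensorflow as tf

skip_count_ph = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(100).skip(skip_count_ph)
iterator = dataset.make_initializable_iterator()
next_elem = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer, feed_dict={skip_count_ph: 10})
    print(sess.run(next_elem))  # 10: the first ten elements were skipped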
Example #18
def train(hparams, scope=None, target_session=""):
    """Train a translation model."""
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 50 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 2 * steps_per_eval

    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
    else:
        raise ValueError("Unknown model architecture")

    train_model = model_helper.create_train_model(model_creator, hparams,
                                                  scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement)

    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
                  hparams, summary_writer, sample_src_data, sample_tgt_data)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats = init_stats()
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(
            session=train_sess), time.ctime()), log_f)

    # Initialize all of the iterators
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    train_sess.run(train_model.iterator.initializer,
                   feed_dict={train_model.skip_count_placeholder: skip_count})

    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError as exp:
            #print(exp)
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            # run_sample_decode(infer_model, infer_sess,
            #                   model_dir, hparams, summary_writer, sample_src_data,
            #                   sample_tgt_data)
            # dev_scores, test_scores, _ = run_external_eval(
            #     infer_model, infer_sess, model_dir,
            #     hparams, summary_writer)
            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            print(steps_per_stats)
            print(steps_per_eval)
            continue

        # Write step summary and accumulate statistics
        global_step = update_stats(stats, summary_writer, start_time,
                                   step_result)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = check_stats(stats, global_step, steps_per_stats,
                                      hparams, log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step

            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess,
                                                  model_dir, hparams,
                                                  summary_writer)

            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hparams, summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step
            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hparams, summary_writer)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)

    result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data)
    utils.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step,
         loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()), log_f)
    utils.print_time("# Done training!", start_train_time)

    summary_writer.close()

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _, _, _, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out(
            "# Best %s, step %d "
            "step-time %.2f wps %.2fK, %s, %s" %
            (metric, best_global_step, avg_step_time, speed, result_summary,
             time.ctime()), log_f)
        summary_writer.close()

    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
Example #19
def single_worker_selfplay(mutable_model, immutable_model, mutable_sess,
                           immutable_sess, selfplay_data_file,
                           selfplay_kb_file, global_step, hparams,
                           summary_writer):
    """selfplay with a single worker.

  This is preminarily used for self play
  evaluation.
  """

    dialogue_mode = dialogue_utils.mode_self_play_dialogue_eval
    # Read self play data
    selfplay_data = dialogue_utils.load_data(selfplay_data_file)
    selfplay_kb = dialogue_utils.load_data(selfplay_kb_file)

    # construct dialogue object
    dialogue = SelfplayDialogue(mutable_model,
                                immutable_model,
                                mutable_sess,
                                immutable_sess,
                                hparams.max_dialogue_turns,
                                hparams.train_threadhold,
                                hparams.start_of_turn1,
                                hparams.start_of_turn2,
                                hparams.end_of_dialogue,
                                summary_writer=summary_writer,
                                dialogue_mode=dialogue_mode,
                                hparams=hparams)

    batch_size = dialogue.self_play_eval_batch_size
    assert batch_size <= len(selfplay_data)

    loaded_mutable, _ = load_self_play_model(dialogue.mutable_model,
                                             dialogue.mutable_sess, 'mutable',
                                             hparams.self_play_pretrain_dir,
                                             hparams.out_dir)
    loaded_immutable, _ = load_self_play_model(dialogue.immutable_model,
                                               dialogue.immutable_sess,
                                               'immutable',
                                               hparams.self_play_pretrain_dir,
                                               hparams.out_dir)
    worker_step = 0
    all_summary = []
    summary_weight = []  # used in combination with all_summary

    # max_eval_per_flip = 100000
    # We flip the role of the agent for exactly two times. In the first iteration
    # when flip = 0, mutable model will be agent 1 and immutable model will be
    # agent 2. The other way around when flip = 1.
    start_time = time.time()
    num_flips_for_initial_speaker = 2
    with tf.gfile.GFile(hparams.selfplay_eval_output_file,
                        'w') as selfplay_out:
        print('flip 1')
        for flip in range(num_flips_for_initial_speaker):
            # epoch = -1
            i = len(selfplay_data)  # force shuffling at the beginning
            agent1, agent2, _ = dialogue.flip_agent(
                (loaded_mutable, mutable_sess, dialogue.mutable_handles),
                (loaded_immutable, immutable_sess, dialogue.immutable_handles),
                flip)
            # only eval one epoch
            # while epoch <= 0:
            # print(i, max_eval_per_flip)
            # if i * batch_size >= len(selfplay_data):  # reached the end
            input_data = zip(selfplay_data, selfplay_kb)
            # we don't shuffle in evaluation
            # random.shuffle(input_data)  # random shuffle input data
            # i = 0
            selfplay_data, selfplay_kb = zip(*input_data)
            # epoch += 1
            ceil = int(math.ceil(len(selfplay_data) * 1.0 / batch_size))
            for i in tqdm(range(0, ceil)):
                start_ind = i * batch_size
                end_ind = min(i * batch_size + batch_size, len(selfplay_data))

                batch_data = selfplay_data[start_ind:end_ind]
                batch_kb = selfplay_kb[start_ind:end_ind]
                # We indicate that agent1 talks first. Keep in mind that we
                # will swap between agent1 and agent2.
                speaker = flip % 2
                generated_data, _, summary = dialogue.talk(
                    hparams.max_dialogue_len, batch_data, batch_kb, agent1,
                    agent2, worker_step, batch_size, speaker)
                output_generated_data(generated_data, selfplay_out)
                all_summary.append(summary)
                # number of elements processed
                summary_weight.append(end_ind - start_ind)
                worker_step += 1
                # i += batch_size
    handle_summary(dialogue_mode, summary_writer, global_step, all_summary,
                   summary_weight)
    end_time = time.time()
    print('finished')
    utils.add_summary(summary_writer, global_step, dialogue_mode + '_time',
                      end_time - start_time)  # step-wise summary
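
The evaluation loop above collects one summary dict per batch in all_summary
and the matching batch sizes in summary_weight, and handle_summary reduces
them into a single value per key. A minimal sketch of that weighted
reduction (the helper below is illustrative; handle_summary itself is not
shown in this snippet):

def weighted_average_summaries(all_summary, summary_weight):
    # all_summary: list of {key: value} dicts, one per evaluated batch.
    # summary_weight: list of batch sizes aligned with all_summary.
    totals, weights = {}, {}
    for summary, weight in zip(all_summary, summary_weight):
        for key, value in summary.items():
            totals[key] = totals.get(key, 0.0) + value * weight
            weights[key] = weights.get(key, 0.0) + weight
    return {key: totals[key] / weights[key] for key in totals}

Weighting by end_ind - start_ind matters because the last batch of an epoch
may be smaller than batch_size.
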
def multi_worker_selfplay(hparams,
                          identity,
                          scope=None,
                          target_session='',
                          is_chief=True,
                          ps_tasks=0,
                          num_workers=1,
                          jobid=0,
                          startup_delay_steps=0):
    """This is the multi worker selfplay, mostly used for self play

  distributed training.
  identity is used.
  """
    immutable_model_reload_freq = hparams.immutable_model_reload_freq
    # 1. models and summary writer
    model_creator = diag_model.Model
    extra_args = model_helper.ExtraArgs(
        single_cell_fn=None,
        model_device_fn=tf.train.replica_device_setter(ps_tasks),
        attention_mechanism_fn=None)

    mutable_model = model_helper.create_selfplay_model(model_creator,
                                                       is_mutable=True,
                                                       num_workers=num_workers,
                                                       jobid=jobid,
                                                       hparams=hparams,
                                                       scope=scope,
                                                       extra_args=extra_args)
    immutable_hparams = copy.deepcopy(hparams)
    immutable_hparams.num_gpus = 0
    immutable_model = model_helper.create_selfplay_model(
        model_creator,
        is_mutable=False,
        num_workers=num_workers,
        jobid=jobid,
        hparams=immutable_hparams,
        scope=scope)

    if hparams.self_play_immutable_gpu:
        print('using GPU for immutable')
        immutable_sess = tf.Session(
            graph=immutable_model.graph,
            config=tf.ConfigProto(allow_soft_placement=True))
    else:
        print('not using GPU for immutable')
        immutable_sess = tf.Session(graph=immutable_model.graph,
                                    config=tf.ConfigProto(
                                        allow_soft_placement=True,
                                        device_count={'GPU': 0}))

    immutable_model, immutable_sess = load_self_play_model(
        immutable_model, immutable_sess, 'immutable',
        hparams.self_play_pretrain_dir, hparams.out_dir)
    global_step = immutable_model.model.global_step.eval(
        session=immutable_sess)

    if is_chief:
        ckpt = tf.train.latest_checkpoint(hparams.out_dir)
        if not ckpt:
            print('global step', global_step,
                  'saving pretrained model to', hparams.out_dir)
            immutable_model.model.saver.save(  # this is to prevent the Adam error
                immutable_sess,
                os.path.join(hparams.out_dir, 'dialogue.ckpt'),
                global_step=global_step)
            print('save finished')

    if is_chief:
        summary_writer_path = os.path.join(
            hparams.out_dir, identity + task_SP_DISTRIBUTED + '_log')
        summary_writer = tf.summary.FileWriter(summary_writer_path,
                                               mutable_model.graph)
        print('summary writer established at', summary_writer_path)
    else:
        summary_writer = None
    # 2. supervisor and sessions

    sv = tf.train.Supervisor(
        graph=mutable_model.graph,
        is_chief=is_chief,
        saver=mutable_model.model.saver,
        save_model_secs=0,  # disable automatic save checkpoints
        summary_op=None,
        logdir=hparams.out_dir,
        checkpoint_basename='dialogue.ckpt')

    mutable_config = utils.get_config_proto(
        log_device_placement=hparams.log_device_placement,
        allow_soft_placement=True)
    mutable_config.device_count['GPU'] = hparams.num_gpus

    mutable_sess = sv.prepare_or_wait_for_session(target_session,
                                                  config=mutable_config)

    # 3. additional preparations
    global_step = mutable_model.model.global_step.eval(session=mutable_sess)
    while global_step < (jobid * (jobid + 1) * startup_delay_steps / 2):
        time.sleep(1)
        global_step = mutable_model.model.global_step.eval(
            session=mutable_sess)

    # save first model
    if is_chief:
        print('saving the first checkpoint to', hparams.out_dir)
        mutable_model.model.saver.save(mutable_sess,
                                       os.path.join(hparams.out_dir,
                                                    'dialogue.ckpt'),
                                       global_step=global_step)
        last_save_step = global_step

    # Read data
    selfplay_data = dialogue_utils.load_data(hparams.self_play_train_data)
    selfplay_kb = dialogue_utils.load_data(hparams.self_play_train_kb)

    dialogue = SelfplayDialogue(mutable_model,
                                immutable_model,
                                mutable_sess,
                                immutable_sess,
                                hparams.max_dialogue_turns,
                                hparams.train_threadhold,
                                hparams.start_of_turn1,
                                hparams.start_of_turn2,
                                hparams.end_of_dialogue,
                                summary_writer=summary_writer,
                                dialogue_mode=task_SP_DISTRIBUTED,
                                hparams=hparams)

    # 4. main loop
    last_immutable_model_reload = global_step
    last_save_step = global_step
    batch_size = dialogue.batch_size
    assert batch_size <= len(selfplay_data)

    # start index into the self-play data; set past the end to force a shuffle on the first pass
    i = len(selfplay_data)
    train_stats = [0, 0]
    while global_step < hparams.num_self_play_train_steps:
        # a. reload the immutable model; the mutable model is automatically
        # managed by the supervisor
        if (immutable_model_reload_freq > 0 and
                global_step - last_immutable_model_reload >
                immutable_model_reload_freq):
            immutable_model, immutable_sess = load_self_play_model(
                immutable_model, immutable_sess, 'immutable',
                hparams.self_play_pretrain_dir, hparams.out_dir)
            last_immutable_model_reload = global_step
        # b. possibly flip between speakers (i.e. rollout models),
        # based on either a random policy or step counts
        agent1, agent2, mutable_agent_index = dialogue.flip_agent(
            (mutable_model, mutable_sess, dialogue.mutable_handles),
            (immutable_model, immutable_sess, dialogue.immutable_handles))
        train_stats[mutable_agent_index] += 1
        # read selfplay data
        start_time = time.time()
        if i * batch_size + batch_size > len(selfplay_data):  # reached the end
            # list() so that random.shuffle works under Python 3, where zip
            # returns an iterator
            input_data = list(zip(selfplay_data, selfplay_kb))
            random.shuffle(input_data)  # shuffle the input data in place
            i = 0
            selfplay_data, selfplay_kb = zip(*input_data)

        start_ind, end_ind = i * batch_size, i * batch_size + batch_size
        batch_data, batch_kb = selfplay_data[start_ind:end_ind], selfplay_kb[
            start_ind:end_ind]
        train_example, _, _ = dialogue.talk(hparams.max_dialogue_len,
                                            batch_data, batch_kb, agent1,
                                            agent2, batch_size, global_step)
        possible_global_step = dialogue.maybe_train(train_example,
                                                    mutable_agent_index,
                                                    global_step,
                                                    force=True)
        if possible_global_step:
            global_step = possible_global_step
        if is_chief and global_step - last_save_step > hparams.self_play_dist_save_freq:
            mutable_model.model.saver.save(mutable_sess,
                                           os.path.join(
                                               hparams.out_dir,
                                               'dialogue.ckpt'),
                                           global_step=global_step)
            last_save_step = global_step
        end_time = time.time()

        if is_chief:
            utils.add_summary(summary_writer, global_step,
                              task_SP_DISTRIBUTED + '_' + 'time',
                              end_time - start_time)
            utils.add_summary(summary_writer, global_step,
                              task_SP_DISTRIBUTED + '_' + 'train_ratio',
                              train_stats[0] * 1.0 / (train_stats[1] + 0.1))
        i += 1

    if is_chief:
        summary_writer.close()

    mutable_sess.close()
    immutable_sess.close()
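
The startup-delay loop in step 3 above holds worker jobid back until
global_step reaches jobid * (jobid + 1) * startup_delay_steps / 2, a
triangular schedule that staggers when workers join training. A
self-contained illustration of the thresholds:

def startup_threshold(jobid, startup_delay_steps):
    # Worker j waits for j * (j + 1) / 2 delay units before training.
    return jobid * (jobid + 1) * startup_delay_steps / 2

# With startup_delay_steps = 100 the thresholds are 0, 100, 300, 600, ...
print([startup_threshold(j, 100) for j in range(4)])
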
Example #21
    def maybe_train(self, sample, speaker, global_step, force=False):
        self.train_samples.append(sample)
        if force or len(self.train_samples) >= self.train_threadhold:
            # first generate training examples
            data_arr = []
            kb_arr = []
            for sample in self.train_samples:  # each sample is a batch of data
                intent, pred_action, truth_action, utterance, kb = sample  # batch version
                all_rewards = dialogue_utils.compute_reward_batch(
                    utterance, pred_action, truth_action, kb,
                    self.hparams)  # batch version
                train_reward, _, _, _, _, _, _, _, _ = all_rewards
                final_reward = train_reward
                reward_diag, reward_action = self.scale_reward_batch(
                    final_reward, self.gamma, utterance)  # in batches
                flat_pred_action = []
                for k in range(len(pred_action)):
                    flat_pred_action.append(' '.join(pred_action[k]))

                new_data_arr = self.format_samples_batch(
                    batch_intent=intent,
                    batch_pred_action=flat_pred_action,
                    batch_truth_action=truth_action,
                    batch_utterance=utterance,
                    batch_reward_diag=reward_diag,
                    batch_reward_action=reward_action,
                    batch_size=self.update_batch_size)
                data_arr.extend(new_data_arr)
                kb_arr.extend(kb)
            data_output, kb_output = data_arr, kb_arr
            new_global_step = None
            self.train_samples = []  # clean up
            self_play_handle = self.mutable_handles[self.iterator_mode]
            if self.hparams.rl_training:
                new_global_step = self.do_rl_training(data_output, kb_output,
                                                      self.update_batch_size,
                                                      self.mutable_model,
                                                      self.mutable_sess,
                                                      speaker, global_step,
                                                      self_play_handle)

            print('self.hparams.self_play_sl_multiplier=',
                  self.hparams.self_play_sl_multiplier)
            if self.hparams.self_play_sl_multiplier >= 0:  # train multiple or don't train at all
                print('do', self.hparams.self_play_sl_multiplier,
                      'supervised training')
                for _ in range(self.hparams.self_play_sl_multiplier):
                    new_global_step = self.do_SL_training(
                        self.mutable_model, self.mutable_sess, global_step,
                        self.mutable_handles[0])
            else:
                print('do one supervised training')
                if self.train_counter >= abs(
                        self.hparams.self_play_sl_multiplier):
                    new_global_step = self.do_SL_training(
                        self.mutable_model, self.mutable_sess, global_step,
                        self.mutable_handles[0])
                    self.train_counter = 0
                else:
                    self.train_counter += 1

            if self.summary_writer:
                utils.add_summary(
                    self.summary_writer, new_global_step,
                    self.dialogue_mode + '_' + 'sl_rl',
                    self.num_sl_updates * 1.0 / (self.num_rl_updates + 0.0001))

            return new_global_step
        return None
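
maybe_train above relies on scale_reward_batch to spread a dialogue-level
reward over the individual turns before formatting RL training examples.
That helper is not shown here; a minimal sketch of one common choice,
discounting the final reward backwards through the turns with factor gamma
(the function and argument names below are assumptions), would be:

def scale_reward(final_reward, gamma, utterance):
    # Give turn t of a T-turn dialogue the reward final_reward * gamma**(T-1-t),
    # so turns closer to the outcome receive more credit.
    num_turns = len(utterance)
    return [final_reward * gamma**(num_turns - 1 - t) for t in range(num_turns)]

print(scale_reward(1.0, 0.9, ['hi', 'hello', 'bye']))  # ~[0.81, 0.9, 1.0]
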
Example #22
def train(hparams, scope=None, target_session="", single_cell_fn=None):
    """Train a translation model."""

    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats

    if hparams.eval_on_fly:
        steps_per_external_eval = hparams.steps_per_external_eval
        steps_per_eval = 10 * steps_per_stats

        if not steps_per_external_eval:
            steps_per_external_eval = 2 * steps_per_eval
    else:
        steps_per_snapshot = hparams.snapshot_interval

    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
    else:
        raise ValueError("Unknown model architecture")

    train_model = create_train_model(model_creator, hparams, scope,
                                     single_cell_fn)

    if hparams.eval_on_fly:
        eval_model = create_eval_model(model_creator, hparams, scope,
                                       single_cell_fn)
        infer_model = inference.create_infer_model(model_creator, hparams,
                                                   scope, single_cell_fn)

        # Preload data for sample decoding.
        dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
        dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
        sample_src_data = inference.load_data(dev_src_file)
        sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement)

    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)

    if hparams.eval_on_fly:
        eval_sess = tf.Session(target=target_session,
                               config=config_proto,
                               graph=eval_model.graph)
        infer_sess = tf.Session(target=target_session,
                                config=config_proto,
                                graph=infer_model.graph)

    with train_model.graph.as_default():
        model_helper.initialize_cnn(train_model.model, train_sess)
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    if hparams.eval_on_fly:
        run_full_eval(model_dir, infer_model, infer_sess, eval_model,
                      eval_sess, hparams, summary_writer, sample_src_data,
                      sample_tgt_data)

    last_stats_step = global_step

    if hparams.eval_on_fly:
        last_eval_step = global_step
        last_external_eval_step = global_step
    else:
        last_snapshot_step = global_step

    # This is the training loop.
    step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
    checkpoint_total_count = 0.0
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(
            session=train_sess), time.ctime()), log_f)

    # Initialize all of the iterators
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    train_sess.run(train_model.iterator.initializer,
                   feed_dict={train_model.skip_count_placeholder: skip_count})

    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            (_, step_loss, step_predict_count, step_summary, global_step,
             step_word_count, batch_size) = step_result
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)

            if hparams.eval_on_fly:
                run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                                  summary_writer, sample_src_data,
                                  sample_tgt_data)
                dev_scores, test_scores, _ = run_external_eval(
                    infer_model, infer_sess, model_dir, hparams,
                    summary_writer)

            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary.
        summary_writer.add_summary(step_summary, global_step)

        # update statistics
        step_time += (time.time() - start_time)

        checkpoint_loss += (step_loss * batch_size)
        checkpoint_predict_count += step_predict_count
        checkpoint_total_count += float(step_word_count)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step

            # Print statistics for the previous steps_per_stats interval.
            avg_step_time = step_time / steps_per_stats
            train_ppl = utils.safe_exp(checkpoint_loss /
                                       checkpoint_predict_count)
            speed = checkpoint_total_count / (1000 * step_time)
            utils.print_out(
                "  global step %d lr %g step-time %.2fs wps %.2fK ppl %.2f %s"
                %
                (global_step,
                 loaded_train_model.learning_rate.eval(session=train_sess),
                 avg_step_time, speed, train_ppl, _get_best_results(hparams)),
                log_f)
            if math.isnan(train_ppl):
                break

            # Reset timer and loss.
            step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
            checkpoint_total_count = 0.0

        if (not hparams.eval_on_fly) and (global_step - last_snapshot_step >=
                                          steps_per_snapshot):
            last_snapshot_step = global_step
            utils.print_out("# Cihan: Saving Snapshot, global step %d" %
                            global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

        if hparams.eval_on_fly and (global_step - last_eval_step >=
                                    steps_per_eval):
            last_eval_step = global_step

            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess,
                                                  model_dir, hparams,
                                                  summary_writer)

        if hparams.eval_on_fly and (global_step - last_external_eval_step >=
                                    steps_per_external_eval):
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            dev_scores, test_scores, _ = run_external_eval(
                infer_model, infer_sess, model_dir, hparams, summary_writer)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)

    if hparams.eval_on_fly:
        result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
            model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
            summary_writer, sample_src_data, sample_tgt_data)

        utils.print_out(
            "# Final, step %d lr %g step-time %.2f wps %.2fK ppl %.2f, %s, %s"
            % (global_step,
               loaded_train_model.learning_rate.eval(session=train_sess),
               avg_step_time, speed, train_ppl, result_summary, time.ctime()),
            log_f)
        utils.print_time("# Done training!", start_train_time)

        utils.print_out("# Start evaluating saved best models.")
        for metric in hparams.metrics:
            best_model_dir = getattr(hparams, "best_" + metric + "_dir")
            result_summary, best_global_step, _, _, _, _ = run_full_eval(
                best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
                hparams, summary_writer, sample_src_data, sample_tgt_data)
            utils.print_out(
                "# Best %s, step %d step-time %.2f wps %.2fK, %s, %s" %
                (metric, best_global_step, avg_step_time, speed,
                 result_summary, time.ctime()), log_f)

    summary_writer.close()

    if hparams.eval_on_fly:
        return dev_scores, test_scores, dev_ppl, test_ppl, global_step
    else:
        return global_step
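
Train perplexity above is utils.safe_exp(checkpoint_loss /
checkpoint_predict_count): the exponential of the average per-prediction
loss, guarded so that a huge early loss reports as infinity instead of
raising. A sketch of such a guard (the actual utils.safe_exp may differ):

import math

def safe_exp(value):
    # math.exp overflows for arguments above roughly 709; report inf
    # instead of raising.
    try:
        return math.exp(value)
    except OverflowError:
        return float('inf')

print(safe_exp(2.3))     # ~9.97: perplexity for an average loss of 2.3 nats
print(safe_exp(1000.0))  # inf
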
Example #23
def _external_eval(model, global_step, sess, hparams, iterator,
                   iterator_feed_dict, tgt_file, label, summary_writer,
                   save_on_best):
    """External evaluation such as BLEU and ROUGE scores."""
    out_dir = hparams.out_dir
    decode = global_step > 0
    if decode:
        utils.print_out("# External evaluation, global step %d" % global_step)

    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)

    output = os.path.join(out_dir, "output_%s" % label)

    scores = nmt_utils.decode_and_evaluate(label,
                                           model,
                                           sess,
                                           output,
                                           ref_file=tgt_file,
                                           metrics=hparams.metrics,
                                           bpe_delimiter=hparams.bpe_delimiter,
                                           beam_width=hparams.beam_width,
                                           tgt_eos=hparams.eos,
                                           decode=decode)
    # Save on best metrics
    if decode:
        for metric in hparams.metrics:
            utils.add_summary(summary_writer, global_step,
                              "%s_%s" % (label, metric), scores[metric])
            # metric: larger is better
            # if save_on_best and scores[metric] > getattr(hparams, "best_" + metric):
            with open("./tmp/nmt_model/score", 'w+') as resu:
                resu.write(str(global_step) + ":" + str(scores[metric]) + "\n")
            if save_on_best and scores[metric] > getattr(hparams,
                                                         "top_score")[0]:
                new_top_score = []
                new_top_score_name = []
                isTopScore = True
                for score, name in zip(getattr(hparams, "top_score"),
                                       getattr(hparams, "top_score_name")):
                    if scores[metric] < score and isTopScore:
                        new_top_score.append(scores[metric])
                        new_top_score_name.append(str(global_step))
                        isTopScore = False
                    new_top_score.append(score)
                    new_top_score_name.append(name)
                if isTopScore:
                    new_top_score.append(scores[metric])
                    new_top_score_name.append(str(global_step))
                setattr(hparams, "top_score", new_top_score[1:])
                setattr(hparams, "top_score_name", new_top_score_name[1:])
                setattr(hparams, "best_" + metric,
                        new_top_score[len(new_top_score) - 1])
                model.saver.save(sess,
                                 os.path.join(
                                     getattr(hparams,
                                             "best_" + metric + "_dir"),
                                     "translate.ckpt"),
                                 global_step=model.global_step)
                if new_top_score[0] != 0:
                    os.system('rm ' +
                              getattr(hparams, "best_" + metric + "_dir") +
                              '/translate.ckpt-' + new_top_score_name[0] + '*')
        utils.save_hparams(out_dir, hparams)
    return scores
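
The save_on_best branch above maintains hparams.top_score and
hparams.top_score_name as parallel ascending lists of the best scores and
their checkpoint steps: each new best score is inserted, the lowest entry is
evicted, and the evicted entry's checkpoint files are deleted. The same
bookkeeping is easier to follow as a standalone helper; a sketch (the names
below are illustrative, not this code's API):

def update_top_scores(top_scores, top_names, new_score, new_name):
    # Insert (new_score, new_name) into fixed-length ascending parallel
    # lists; return the updated lists plus the evicted (score, name) pair.
    merged = sorted(zip(top_scores + [new_score], top_names + [new_name]))
    evicted, kept = merged[0], merged[1:]
    return [s for s, _ in kept], [n for _, n in kept], evicted

scores, names, evicted = update_top_scores([0.1, 0.3], ['100', '200'], 0.2, '300')
print(scores, names, evicted)  # [0.2, 0.3] ['300', '200'] (0.1, '100')
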
Example #24
def train(hparams, scope=None, target_session="", compute_ppl=0):
    """Train a translation model."""
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    avg_ckpts = hparams.avg_ckpts

    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if not hparams.attention:  # choose this model
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt"
                or hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    train_model = model_helper.create_train_model(model_creator, hparams,
                                                  scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)
    wsd_src_file = "%s" % (hparams.sample_prefix)

    wsd_src_data = inference.load_data(wsd_src_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation (disabled in this variant):
    # run_full_eval(
    #     model_dir, infer_model, infer_sess,
    #     eval_model, eval_sess, hparams,
    #     summary_writer, sample_src_data,
    #     sample_tgt_data, avg_ckpts)
    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats, info, start_train_time = before_train(loaded_train_model,
                                                 train_model, train_sess,
                                                 global_step, hparams, log_f)
    end_step = global_step + 100
    while global_step < end_step:  # num_train_steps
        ### Run a step ###
        start_time = time.time()
        try:
            # then forward inference result to WSD, get reward
            step_result = loaded_train_model.train(train_sess)
            # forward reward to placeholder of loaded_train_model, and write a new train function where loss = loss*reward
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)

            # run_sample_decode(infer_model, infer_sess, model_dir, hparams,
            #                   summary_writer, sample_src_data, sample_tgt_data)

            # only for pretrain
            # run_external_eval(infer_model, infer_sess, model_dir, hparams,
            #                   summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})

            continue

        # Process step_result, accumulate stats, and write summary
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result, hparams)
        summary_writer.add_summary(step_summary, global_step)
        if compute_ppl:
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)
        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = process_stats(stats, info, global_step,
                                        steps_per_stats, log_f)
            print_step_info("  ", global_step, info,
                            _get_best_results(hparams), log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              info["train_ppl"])

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)
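
The comments in this training loop sketch a plan: run inference, score the
result with an external WSD module, and feed that score back so the loss
becomes loss * reward. A minimal TF 1.x sketch of such a reward-weighted
sequence loss (every tensor name here is an assumption, not this
repository's API):

import tensorflow as tf

logits = tf.placeholder(tf.float32, [None, None, 8])       # [batch, time, vocab]
labels = tf.placeholder(tf.int32, [None, None])            # [batch, time]
target_weights = tf.placeholder(tf.float32, [None, None])  # 0/1 padding mask
reward = tf.placeholder(tf.float32, [None])                # one reward per example

crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=labels, logits=logits)                          # [batch, time]
per_example_loss = tf.reduce_sum(crossent * target_weights, axis=1)
loss = tf.reduce_mean(per_example_loss * reward)           # loss scaled by reward

At train time the reward placeholder would be fed with scores computed
outside the graph, exactly as the comments above propose.
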
Example #25
def train(hparams, scope=None, target_session=''):
    """Train the chatbot"""
    # Initialize some local hyperparameters
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if hparams.architecture == "simple":
        model_creator = SimpleModel
        get_infer_iterator = iterator_utils.get_infer_iterator
        get_iterator = iterator_utils.get_iterator
    elif hparams.architecture == "hier":
        model_creator = HierarchicalModel
        # Parse some of the arguments now
        def curry_get_infer_iterator(dataset, vocab_table, batch_size,
                                     src_reverse, eos, src_max_len):
            return end2end_iterator_utils.get_infer_iterator(
                dataset, vocab_table, batch_size, src_reverse, eos,
                src_max_len=src_max_len, eou=hparams.eou,
                dialogue_max_len=hparams.dialogue_max_len)
        get_infer_iterator = curry_get_infer_iterator

        def curry_get_iterator(src_dataset,
                               tgt_dataset,
                               vocab_table,
                               batch_size,
                               sos,
                               eos,
                               src_reverse,
                               random_seed,
                               num_buckets,
                               src_max_len=None,
                               tgt_max_len=None,
                               num_threads=4,
                               output_buffer_size=None,
                               skip_count=None):
            return end2end_iterator_utils.get_iterator(
                src_dataset, tgt_dataset, vocab_table, batch_size, sos, eos,
                eou=hparams.eou, src_reverse=src_reverse,
                random_seed=random_seed, num_dialogue_buckets=num_buckets,
                src_max_len=src_max_len, tgt_max_len=tgt_max_len,
                num_threads=num_threads,
                output_buffer_size=output_buffer_size, skip_count=skip_count)

        get_iterator = curry_get_iterator
    else:
        raise ValueError("Unkown architecture", hparams.architecture)

    # Create three models which share parameters through the use of checkpoints
    train_model = create_train_model(model_creator, get_iterator, hparams, scope)
    eval_model = create_eval_model(model_creator, get_iterator, hparams, scope)
    infer_model = inference.create_infer_model(model_creator, get_infer_iterator, hparams, scope)
    # ToDo: adapt for architectures
    # Preload the data to use for sample decoding

    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    avg_step_time = 0.0

    # Create the configurations for the sessions
    config_proto = utils.get_config_proto(log_device_placement=log_device_placement)
    # Create three sessions, one for each model
    train_sess = tf.Session(target=target_session, config=config_proto, graph=train_model.graph)
    eval_sess = tf.Session(target=target_session, config=config_proto, graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session, config=config_proto, graph=infer_model.graph)

    # Load the train model from checkpoint or create a new one
    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(train_model.model, model_dir,
                                                                            train_sess, name="train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(
        os.path.join(out_dir, summary_name), train_model.graph)
    # First evaluation
    run_full_eval(
        model_dir, infer_model, infer_sess,
        eval_model, eval_sess, hparams,
        summary_writer, sample_src_data,
        sample_tgt_data)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    # Initialize the hyperparameters for the loop.
    step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
    checkpoint_total_count = 0.0
    speed, train_ppl = 0.0, 0.0
    start_train_time = time.time()

    utils.print_out(
        "# Start step %d, lr %g, %s" %
        (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
         time.ctime()),
        log_f)

    # epoch_step records where we were within an epoch; used to skip examples the model has already seen
    skip_count = hparams.batch_size * hparams.epoch_step
    utils.print_out("# Init train iterator, skipping %d elements" % skip_count)
    # Initialize the training iterator
    train_sess.run(
        train_model.iterator.initializer,
        feed_dict={train_model.skip_count_placeholder: skip_count})

    # Train until we reach num_steps.
    while global_step < num_train_steps:
        # Run a step
        start_step_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            (_, step_loss, step_predict_count, step_summary, global_step,  # The _ is the output of the update op
             step_word_count, batch_size) = step_result
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            # Decode and print a random sentence
            run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            # Perform external evaluation to save checkpoints if this is the best for some metric
            dev_scores, test_scores, _ = run_external_evaluation(infer_model, infer_sess, model_dir, hparams,
                                                                 summary_writer, save_on_best_dev=True)
            # Reinitialize the iterator from the beginning
            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Write step summary.
        summary_writer.add_summary(step_summary, global_step)

        # update statistics
        step_time += (time.time() - start_step_time)

        checkpoint_loss += (step_loss * batch_size)
        checkpoint_predict_count += step_predict_count
        checkpoint_total_count += float(step_word_count)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step

            # Print statistics for the previous steps_per_stats interval.
            avg_step_time = step_time / steps_per_stats
            train_ppl = utils.safe_exp(checkpoint_loss / checkpoint_predict_count)
            speed = checkpoint_total_count / (1000 * step_time)
            utils.print_out(
                "  global step %d lr %g "
                "step-time %.2fs wps %.2fK ppl %.2f %s" %
                (global_step,
                 loaded_train_model.learning_rate.eval(session=train_sess),
                 avg_step_time, speed, train_ppl, _get_best_results(hparams)),
                log_f)
            if math.isnan(train_ppl):
                # The model has screwed up
                break

            # Reset timer and loss.
            step_time, checkpoint_loss, checkpoint_predict_count = 0.0, 0.0, 0.0
            checkpoint_total_count = 0.0

        if global_step - last_eval_step >= steps_per_eval:
            # Perform evaluation. Start by reassigning the last_eval_step variable to the current step
            last_eval_step = global_step
            # Print the progress and add summary
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl", train_ppl)

            # Save checkpoint
            loaded_train_model.saver.save(train_sess, os.path.join(out_dir, "chatbot.ckpt"), global_step=global_step)
            # Decode and print a random sample
            run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            # Run internal evaluation, and update the ppl variables. The data iterator is instantiated in the method.
            dev_ppl, test_ppl = run_internal_eval(eval_model, eval_sess, model_dir, hparams, summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            # Run the external evaluation
            last_external_eval_step = global_step
            # Save checkpoint
            loaded_train_model.saver.save(train_sess, os.path.join(out_dir, "chatbot.ckpt"), global_step=global_step)
            # Decode and print a random sample
            run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
                              sample_src_data, sample_tgt_data)
            # Run external evaluation, updating metric scores in the meanwhile. The unneeded output is the global step.
            dev_scores, test_scores, _ = run_external_evaluation(infer_model, infer_sess, model_dir, hparams,
                                                                 summary_writer, save_on_best_dev=True)

    # Done training. Save the model
    loaded_train_model.saver.save(
        train_sess,
        os.path.join(out_dir, "chatbot.ckpt"),
        global_step=global_step)

    result_summary, _, dev_scores, test_scores, dev_ppl, test_ppl = run_full_eval(
        model_dir, infer_model, infer_sess,
        eval_model, eval_sess, hparams,
        summary_writer, sample_src_data,
        sample_tgt_data)
    utils.print_out(
        "# Final, step %d lr %g "
        "step-time %.2f wps %.2fK ppl %.2f, %s, %s" %
        (global_step, loaded_train_model.learning_rate.eval(session=train_sess),
         avg_step_time, speed, train_ppl, result_summary, time.ctime()),
        log_f)
    utils.print_time("# Done training!", start_train_time)

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        result_summary, best_global_step, _, _, _, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
            summary_writer, sample_src_data, sample_tgt_data)
        utils.print_out("# Best %s, step %d "
                        "step-time %.2f wps %.2fK, %s, %s" %
                        (metric, best_global_step, avg_step_time, speed,
                         result_summary, time.ctime()), log_f)

    summary_writer.close()
    return (dev_scores, test_scores, dev_ppl, test_ppl, global_step)
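
Like the other train() variants above, this loop resumes mid-epoch by
skipping hparams.batch_size * hparams.epoch_step already-seen examples when
initializing the train iterator. The underlying mechanism is tf.data's
skip(); a minimal self-contained sketch under TF 1.x:

import tensorflow as tf

skip_count_placeholder = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(10).skip(skip_count_placeholder)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    # Resume as if 4 examples of the epoch had already been consumed.
    sess.run(iterator.initializer, feed_dict={skip_count_placeholder: 4})
    print(sess.run(next_element))  # 4
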
Example #26
def train(hparams, scope=None, target_session=""):
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    avg_ckpts = hparams.avg_ckpts

    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if not hparams.attention:
        model_creator = nmt_model.Model
    else:
        if hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    train_model = model_util.create_train_model(model_creator, hparams, scope)
    eval_model = model_util.create_eval_model(model_creator, hparams, scope)
    infer_model = model_util.create_infer_model(model_creator, hparams, scope)

    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_util.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
                  hparams, summary_writer, sample_src_data, sample_tgt_data,
                  avg_ckpts)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    stats, info, start_train_time = before_train(loaded_train_model,
                                                 train_model, train_sess,
                                                 global_step, hparams, log_f)

    while global_step < num_train_steps:
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)

        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = process_stats(stats, info, global_step,
                                        steps_per_stats, log_f)
            print_step_info("  ", global_step, info,
                            _get_best_results(hparams), log_f)
            if is_overflow:
                break

            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_perplexity",
                              info["train_perplexity"])

            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)

    result_summary, _, final_eval_metrics = run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data, avg_ckpts)
    print_step_info("# Final, ", global_step, info, result_summary, log_f)
    utils.print_time("# Done training!", start_train_time)

    summary_writer.close()

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        print_step_info("# Best %s, " % metric, best_global_step, info,
                        result_summary, log_f)
        summary_writer.close()

        if avg_ckpts:
            best_model_dir = getattr(hparams, "avg_best_" + metric + "_dir")
            summary_writer = tf.summary.FileWriter(
                os.path.join(best_model_dir, summary_name), infer_model.graph)
            result_summary, best_global_step, _ = run_full_eval(
                best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
                hparams, summary_writer, sample_src_data, sample_tgt_data)
            print_step_info("# Averaged Best %s, " % metric, best_global_step,
                            info, result_summary, log_f)
            summary_writer.close()

    return final_eval_metrics, global_step
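
When avg_ckpts is set, the external evaluations above also score a model
whose weights are the element-wise mean of the last few checkpoints, which
usually yields a smoother metric curve. A minimal sketch of the averaging
itself, assuming TF 1.x checkpoint files (the real run_avg_external_eval
path handles more edge cases):

import numpy as np
import tensorflow as tf

def average_checkpoints(checkpoint_paths):
    # Return {variable_name: averaged ndarray} over the given checkpoints;
    # non-float variables (e.g. global_step) are skipped.
    averaged = {}
    for path in checkpoint_paths:
        reader = tf.train.NewCheckpointReader(path)
        for name in reader.get_variable_to_shape_map():
            tensor = reader.get_tensor(name)
            if np.issubdtype(tensor.dtype, np.floating):
                averaged[name] = averaged.get(name, 0.0) + tensor / len(checkpoint_paths)
    return averaged

The averaged arrays would then be assigned into a fresh copy of the graph
and saved as a new checkpoint for evaluation.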