Beispiel #1
0
def _external_eval(model,
                   global_step,
                   sess,
                   hparams,
                   iterator,
                   iterator_feed_dict,
                   tgt_file,
                   label,
                   summary_writer,
                   save_on_best,
                   avg_ckpts=False):
    """Decode with `model` and score the output against `tgt_file`.

    The iterator is re-initialized with `iterator_feed_dict`, decoded output
    is written to `<hparams.out_dir>/output_<label>`, and one score per entry
    of `hparams.metrics` is computed by `nmt_utils.decode_and_evaluate`.

    When `global_step > 0`, each score is also written as a summary, the
    best-so-far value stored on `hparams` is updated (if `save_on_best` is
    set and the score improved), and `hparams` is saved to the output dir.

    Args:
        model: Loaded inference model; its `saver` and `global_step` are used
            when a new best checkpoint is written.
        global_step: Step number used for logging and summaries.
        sess: Session in which the iterator and the model are run.
        hparams: Hyperparameters; reads `out_dir`, `metrics`,
            `subword_option`, `beam_width`, `eos` and the
            `best_<metric>` / `avg_best_<metric>` / `best_<metric>_dir`
            attributes.
        iterator: Dataset iterator exposing an `initializer` op.
        iterator_feed_dict: Feed dict used to initialize `iterator`.
        tgt_file: Reference file the decoded output is scored against.
        label: Name of the evaluated split; prefixed with "avg_" when
            `avg_ckpts` is true.
        summary_writer: Writer receiving the per-metric summaries.
        save_on_best: Whether to checkpoint the model on a new best score.
        avg_ckpts: Whether the evaluated model is an averaged checkpoint.

    Returns:
        The scores object returned by `nmt_utils.decode_and_evaluate`,
        indexed by metric name.
    """
    eval_label = "avg_" + label if avg_ckpts else label
    best_prefix = "avg_best_" if avg_ckpts else "best_"
    out_dir = hparams.out_dir

    utils.print_out("# External evaluation, global step %d" % global_step)

    # Rewind the dataset before decoding.
    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)

    translation_path = os.path.join(out_dir, "output_%s" % eval_label)
    scores = nmt_utils.decode_and_evaluate(
        eval_label,
        model,
        sess,
        translation_path,
        ref_file=tgt_file,
        metrics=hparams.metrics,
        subword_option=hparams.subword_option,
        beam_width=hparams.beam_width,
        tgt_eos=hparams.eos,
        hparams=hparams,
        decode=True)

    # Nothing is recorded or saved for the step-0 evaluation.
    if global_step <= 0:
        return scores

    for metric in hparams.metrics:
        best_attr = best_prefix + metric
        summary_tag = "%s_%s" % (eval_label, metric)
        utils.add_summary(summary_writer, global_step, summary_tag,
                          scores[metric])

        # Larger is better for every metric handled here.
        if save_on_best and scores[metric] > getattr(hparams, best_attr):
            setattr(hparams, best_attr, scores[metric])
            # NOTE(review): the checkpoint directory is always
            # `best_<metric>_dir`, even for averaged checkpoints whose score
            # is tracked under `avg_best_<metric>` -- confirm this sharing
            # is intended.
            best_dir = getattr(hparams, "best_" + metric + "_dir")
            model.saver.save(sess,
                             os.path.join(best_dir, "translate.ckpt"),
                             global_step=model.global_step)

    # Persist the (possibly updated) best scores.
    utils.save_hparams(out_dir, hparams)
    return scores
def _internal_eval(hparams, model, global_step, sess, iterator,
                   iterator_feed_dict, summary_writer, label):
    """Compute perplexity for `model` and log it as the `<label>_ppl` summary.

    Args:
        hparams: Hyperparameters forwarded to the perplexity computation.
        model: Loaded evaluation model.
        global_step: Step number used for logging and for the summary.
        sess: Session in which the iterator and the model are run.
        iterator: Dataset iterator exposing an `initializer` op.
        iterator_feed_dict: Feed dict used to initialize `iterator`.
        summary_writer: Writer receiving the perplexity summary.
        label: Name of the evaluated split; used in the summary tag.

    Returns:
        The perplexity returned by `model_helper.compute_perplexity`.
    """
    banner = "# Internal evaluation (perplexity), global step %d" % global_step
    utils.print_out(banner)

    # Rewind the dataset so the whole split is scored.
    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)
    perplexity = model_helper.compute_perplexity(hparams, model, sess, label)

    summary_tag = "%s_ppl" % label
    utils.add_summary(summary_writer, global_step, summary_tag, perplexity)
    return perplexity
Beispiel #3
0
def main(argv):
    """Train the reformulator and/or the selector with periodic dev evaluation.

    Builds a reformulator client and two environment reward functions, reads
    the train and dev data, then iterates over `FLAGS.epochs` epochs of
    training batches. Every `FLAGS.num_steps_per_eval` global steps the dev
    set is evaluated. Depending on `FLAGS.enable_reformulator_training` and
    `FLAGS.enable_selector_training`, each batch trains the reformulator,
    the selector, or both. Metrics are written to a TensorBoard event file.

    Args:
        argv: Command-line arguments; unused.
    """
    del argv  # Unused.

    if FLAGS.debug:
        random.seed(0)

    reformulator_instance = reformulator.Reformulator(
        hparams_path=FLAGS.hparams_path,
        source_prefix=FLAGS.source_prefix,
        out_dir=FLAGS.out_dir,
        environment_server_address=FLAGS.environment_server_address)
    environment_fn = environment_client.make_environment_reward_fn(
        FLAGS.environment_server_address,
        mode=FLAGS.mode,
        env_call_parallelism=FLAGS.env_sample_parallelism)

    # Separate reward function for selector evaluation/training; its mode is
    # fixed to 'searchqa' regardless of FLAGS.mode.
    eval_environment_fn = environment_client.make_environment_reward_fn(
        FLAGS.environment_server_address,
        mode='searchqa',
        env_call_parallelism=FLAGS.env_eval_parallelism)

    # Read data.
    questions, annotations, docid_2_answer = read_data(
        questions_file=FLAGS.train_questions,
        annotations_file=FLAGS.train_annotations,
        answers_file=FLAGS.train_data,
        preprocessing_mode=FLAGS.mode)
    dev_questions, dev_annotations, dev_docid_2_answer = read_data(
        questions_file=FLAGS.dev_questions,
        annotations_file=FLAGS.dev_annotations,
        answers_file=FLAGS.dev_data,
        preprocessing_mode=FLAGS.mode,
        max_lines=FLAGS.max_dev_examples)

    # Summary writer that writes events to a folder. TensorBoard will later read
    # from it. The timestamp keeps the logs of separate runs apart.
    summary_writer = tf.summary.FileWriter(
        os.path.join(
            FLAGS.tensorboard_dir,
            'reformulator_and_selector_training_log_' + str(time.time())))

    if FLAGS.enable_selector_training:
        selector_model = selector.Selector()
        last_save_step = 0

    global_step = 0
    for epoch in range(FLAGS.epochs):
        for batch_id, (questions_batch, annotations_batch) in enumerate(
                batch(questions, annotations, FLAGS.batch_size_train)):
            # Run eval every num_steps_per_eval batches. Integers are compared
            # by value (`==`), not by identity (`is`).
            if global_step % FLAGS.num_steps_per_eval == 0:
                if FLAGS.debug:
                    print('Running eval...')
                eval_start_time = time.time()

                if not FLAGS.enable_selector_training:
                    eval_f1_avg = _run_reformulator_eval(
                        dev_questions, dev_annotations, reformulator_instance,
                        environment_fn, FLAGS.batch_size_eval)
                else:
                    eval_f1_avg = _run_eval_with_selector(
                        questions=dev_questions,
                        annotations=dev_annotations,
                        docid_2_answer=dev_docid_2_answer,
                        reformulator_instance=reformulator_instance,
                        selector_model=selector_model,
                        batch_size=FLAGS.batch_size_eval,
                        environment_fn=eval_environment_fn)

                # Correct the average F1 score for deleted datapoints in the SearchQA
                # dataset.
                if FLAGS.mode == 'searchqa':
                    eval_f1_avg = _correct_searchqa_score(eval_f1_avg,
                                                          dataset='dev')

                eval_time = time.time() - eval_start_time

                misc_utils.add_summary(summary_writer,
                                       global_step,
                                       tag='eval_f1_avg',
                                       value=eval_f1_avg)
                misc_utils.add_summary(summary_writer,
                                       global_step,
                                       tag='eval_time',
                                       value=eval_time)

                if FLAGS.debug:
                    print('Avg F1 on dev: {}.'.format(eval_f1_avg))
                    print('Time to finish eval: {}'.format(eval_time))

            start_time = time.time()
            if FLAGS.debug:
                print('Epoch {}, Batch {}.'.format(epoch, batch_id))
                print('Question: [{}]; Id: {}'.format(questions_batch[0],
                                                      annotations_batch[0]))

            # Retrieve rewrites for selector training using beam search.
            if FLAGS.enable_selector_training:
                responses_beam = reformulator_instance.reformulate(
                    questions=questions_batch,
                    inference_mode=reformulator_pb2.ReformulatorRequest.
                    BEAM_SEARCH)

                # Discard answers.
                reformulations_beam = [[rf.reformulation for rf in rsp]
                                       for rsp in responses_beam]

            if FLAGS.enable_reformulator_training:
                # Train reformulator model.
                if FLAGS.debug:
                    print('Training reformulator...')

                reformulator_loss, f1s, reformulations = reformulator_instance.train(
                    sources=questions_batch, annotations=annotations_batch)

                f1_avg = f1s.mean()

                # NOTE: skipping here also skips selector training for this
                # batch and leaves global_step unchanged, so the next batch
                # reuses the same step number.
                if [] in reformulations:
                    if FLAGS.debug:
                        print('Found empty rewrites! Skipping this batch.')
                    continue

                if FLAGS.debug:
                    print('Rewrite: {}'.format(safe_string(reformulations[0])))
                    print('Avg F1: {}'.format(f1_avg))
                    print('Loss  : {}'.format(reformulator_loss))

                # Write the f1_avg and loss to Tensorboard.
                misc_utils.add_summary(summary_writer,
                                       global_step,
                                       tag='f1_avg',
                                       value=f1_avg)
                misc_utils.add_summary(summary_writer,
                                       global_step,
                                       tag='reformulator_loss',
                                       value=reformulator_loss)

            # Train selector model.
            if FLAGS.enable_selector_training:
                (selector_questions, selector_answers,
                 selector_scores) = query_environment(
                     original_questions=questions_batch,
                     rewrites=reformulations_beam,
                     annotations=annotations_batch,
                     environment_fn=eval_environment_fn,
                     docid_2_answer=docid_2_answer,
                     token_level_f1_scores=False)

                if FLAGS.debug:
                    print('Training selector...')

                train_selector_loss, train_selector_accuracy = selector_model.train(
                    selector_questions, selector_answers, selector_scores)

                # Regularly save a checkpoint.
                if global_step - last_save_step >= FLAGS.steps_per_save_selector:
                    selector_model.save(str(global_step))
                    last_save_step = global_step
                    print('Selector saved at step: {}'.format(global_step))

                if FLAGS.debug:
                    print('Train Accuracy: {}'.format(train_selector_accuracy))
                    print('Train Loss    : {}'.format(train_selector_loss))

                # Write the accuracy and loss to Tensorboard.
                misc_utils.add_summary(summary_writer,
                                       global_step,
                                       tag='train_selector_accuracy',
                                       value=train_selector_accuracy)
                misc_utils.add_summary(summary_writer,
                                       global_step,
                                       tag='train_selector_loss',
                                       value=train_selector_loss)

            iteration_time = time.time() - start_time
            if FLAGS.debug:
                print('Iteration time: {}'.format(iteration_time))

            misc_utils.add_summary(summary_writer,
                                   global_step,
                                   tag='iteration_time',
                                   value=iteration_time)

            # Increment the global counter
            global_step += 1

    # Flush pending events and release the event file once training is done.
    summary_writer.close()
def train(hparams, scope=None, target_session=""):
    """Train a translation model.

    Builds the train, eval and inference models in one shared graph and runs
    them in a single session. The training loop runs until the global step
    reaches `hparams.num_train_steps` (or `process_stats` reports an
    overflow), interleaved with:

      * statistics printing every `hparams.steps_per_stats` steps,
      * checkpoint + sample decode + (unless `hparams.use_rl`) perplexity
        evaluation every `10 * steps_per_stats` steps,
      * checkpoint + sample decode + external evaluation every
        `hparams.steps_per_external_eval` steps (defaults to five times the
        internal evaluation interval when unset), and at each epoch end.

    Args:
        hparams: Hyperparameters controlling model selection, data files,
            output directory and schedule. `hparams.epoch_step` is updated
            in place during training.
        scope: Optional scope forwarded to the model builders.
        target_session: Execution target passed to `tf.Session`.

    Returns:
        A tuple `(final_eval_metrics, global_step)`: the third element of the
        final `run_full_eval` result and the last global step reached.

    Raises:
        ValueError: If attention is enabled with an unknown
            `hparams.attention_architecture`.
    """
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats

    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if not hparams.attention:
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt"
                or hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    # All three models are built in the same graph so that one session (and
    # one set of variables) can serve training, evaluation and inference.
    combined_graph = tf.Graph()
    train_model = model_helper.create_train_model(model_creator,
                                                  hparams,
                                                  scope,
                                                  graph=combined_graph)
    eval_model = model_helper.create_eval_model(model_creator,
                                                hparams,
                                                scope,
                                                graph=combined_graph)
    infer_model = model_helper.create_infer_model(model_creator,
                                                  hparams,
                                                  scope,
                                                  graph=combined_graph)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    dev_ctx_file = None
    if hparams.ctx is not None:
        dev_ctx_file = "%s.%s" % (hparams.dev_prefix, hparams.ctx)

    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)
    sample_ctx_data = None
    if dev_ctx_file is not None:
        sample_ctx_data = inference.load_data(dev_ctx_file)

    sample_annot_data = None
    if hparams.dev_annotations is not None:
        sample_annot_data = inference.load_data(hparams.dev_annotations)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    sess = tf.Session(target=target_session,
                      config=config_proto,
                      graph=combined_graph)

    with train_model.graph.as_default():
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation. The same session is passed for both the inference and
    # the evaluation model because they share `combined_graph`.
    run_full_eval(infer_model, sess, eval_model, sess, hparams, summary_writer,
                  sample_src_data, sample_ctx_data, sample_tgt_data,
                  sample_annot_data)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats, info, start_train_time = before_train(loaded_train_model,
                                                 train_model, sess,
                                                 global_step, hparams, log_f)
    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(sess)
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            run_sample_decode(infer_model, sess, hparams, summary_writer,
                              sample_src_data, sample_ctx_data,
                              sample_tgt_data, sample_annot_data)
            run_external_eval(infer_model, sess, hparams, summary_writer)

            # Restart the training iterator from the beginning of the data.
            sess.run(train_model.iterator.initializer,
                     feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Process step_result, accumulate stats, and write summary
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = process_stats(stats, info, global_step,
                                        steps_per_stats, log_f)
            print_step_info("  ", global_step, info,
                            _get_best_results(hparams), log_f)
            # Stop training early when the statistics report an overflow.
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              info["train_ppl"])

            # Save checkpoint
            loaded_train_model.saver.save(sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, sess, hparams, summary_writer,
                              sample_src_data, sample_ctx_data,
                              sample_tgt_data, sample_annot_data)

            dev_ppl, test_ppl = None, None
            # only evaluate perplexity when supervised learning
            if not hparams.use_rl:
                dev_ppl, test_ppl = run_internal_eval(eval_model, sess,
                                                      hparams, summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, sess, hparams, summary_writer,
                              sample_src_data, sample_ctx_data,
                              sample_tgt_data, sample_annot_data)
            run_external_eval(infer_model, sess, hparams, summary_writer)

    # Done training
    loaded_train_model.saver.save(sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)

    (result_summary, _, final_eval_metrics) = (run_full_eval(
        infer_model, sess, eval_model, sess, hparams, summary_writer,
        sample_src_data, sample_ctx_data, sample_tgt_data, sample_annot_data))

    print_step_info("# Final, ", global_step, info, result_summary, log_f)
    utils.print_time("# Done training!", start_train_time)

    summary_writer.close()
    # The log file is not written to past this point; close it so buffered
    # output is flushed and the handle is released.
    log_f.close()
    return final_eval_metrics, global_step