# NOTE(review): removed paste artifact ("Beispiel #1" / "0") — bare names and
# literals at module top level would raise NameError / do nothing at import.
    def __init__(self, hparams_path, source_prefix, out_dir,
                 environment_server_address):
        """Constructor for the reformulator.

        Builds a training model plus one inference model per decoding
        strategy (greedy, sampling, beam search, and their trie-constrained
        variants), all sharing a single graph and session, then initializes
        variables and loads model weights from a checkpoint.

        Args:
          hparams_path: Path to json hparams file.
          source_prefix: A prefix that is added to every question before
            translation which should be used for adding tags like <en> <2en>.
            Can be empty or None in which case the prefix is ignored.
          out_dir: Directory where the model output will be written.
          environment_server_address: Address of the environment server.

        Raises:
          ValueError: if model architecture is not known.
        """
        self.hparams = load_hparams(hparams_path, out_dir)
        assert self.hparams.num_buckets == 1, "No bucketing when in server mode."
        assert not self.hparams.server_mode, (
            "server_mode set to True but not "
            "running as server.")

        self.hparams.environment_server = environment_server_address

        # Optional SentencePiece subword segmentation model.
        if self.hparams.subword_option == "spm":
            self.sentpiece = sentencepiece_processor.SentencePieceProcessor()
            self.sentpiece.Load(self.hparams.subword_model.encode("utf-8"))
        self.source_prefix = source_prefix

        # Select the model architecture from hparams.
        if not self.hparams.attention:
            model_creator = nmt_model.Model
        elif self.hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        elif self.hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
            model_creator = gnmt_model.GNMTModel
        else:
            raise ValueError("Unknown model architecture")

        # Decoder trie used to constrain decoding; optionally preloaded from
        # a text file if hparams.trie_path exists.
        self.trie = trie_decoder_utils.DecoderTrie(
            vocab_path=self.hparams.tgt_vocab_file,
            eos_token=self.hparams.eos,
            subword_option=self.hparams.subword_option,
            subword_model=self.hparams.get("subword_model"),
            optimize_ngrams_len=self.hparams.optimize_ngrams_len)
        if self.hparams.trie_path is not None and tf.gfile.Exists(
                self.hparams.trie_path):
            self.trie.populate_from_text_file(self.hparams.trie_path)

        combined_graph = tf.Graph()
        self.train_model = model_helper.create_train_model(
            model_creator,
            self.hparams,
            graph=combined_graph,
            trie=self.trie,
            use_placeholders=True)

        # Create one inference model per decoding strategy.  hparams is
        # temporarily mutated for each create_infer_model() call and restored
        # only after *all* models are built.  (Bug fix: the previous code
        # restored the defaults between TRIE_GREEDY and TRIE_SAMPLE, leaving
        # the trie_beam_search settings active in hparams after construction.)
        default_infer_mode = self.hparams.infer_mode
        default_beam_width = self.hparams.beam_width
        beam = max(1, default_beam_width)
        infer_configs = [
            (reformulator_pb2.ReformulatorRequest.GREEDY, "greedy", 0, False),
            (reformulator_pb2.ReformulatorRequest.SAMPLING, "sample", 0,
             False),
            (reformulator_pb2.ReformulatorRequest.BEAM_SEARCH, "beam_search",
             beam, False),
            (reformulator_pb2.ReformulatorRequest.TRIE_GREEDY, "trie_greedy",
             0, True),
            (reformulator_pb2.ReformulatorRequest.TRIE_SAMPLE, "trie_sample",
             0, True),
            (reformulator_pb2.ReformulatorRequest.TRIE_BEAM_SEARCH,
             "trie_beam_search", beam, True),
        ]
        self.infer_models = {}
        self.hparams.use_rl = False
        for request_type, infer_mode, beam_width, use_trie in infer_configs:
            self.hparams.infer_mode = infer_mode
            self.hparams.beam_width = beam_width
            # Pass the trie kwarg only for trie-constrained modes, matching
            # the original per-mode calls.
            kwargs = {"trie": self.trie} if use_trie else {}
            self.infer_models[request_type] = model_helper.create_infer_model(
                model_creator, self.hparams, graph=combined_graph, **kwargs)

        # Restore the caller-visible hparams defaults.
        self.hparams.infer_mode = default_infer_mode
        self.hparams.beam_width = default_beam_width
        self.hparams.use_rl = True

        self.sess = tf.Session(graph=combined_graph,
                               config=misc_utils.get_config_proto())

        with combined_graph.as_default():
            # HACK: checkpoint path is hard-coded to a local machine.  Restore
            # the out_dir-based loading (create_or_load_model(..., out_dir,
            # ...)) before deploying anywhere else.
            p1 = ('C:/Users/aanamika/Documents/QuestionGeneration/'
                  'active-qa-master/tmp/active-qa/temp/translate.ckpt-1460356')
            print('p1:', p1)
            self.sess.run(tf.global_variables_initializer())
            self.sess.run(tf.tables_initializer())
            _, global_step = model_helper.create_or_load_model(
                self.train_model.model, p1, self.sess, "train")
            self.last_save_step = global_step

        self.summary_writer = tf.summary.FileWriter(
            os.path.join(out_dir, "train_log"), self.train_model.graph)
        self.checkpoint_path = os.path.join(out_dir, "translate.ckpt")
        self.trie_save_path = os.path.join(out_dir, "trie")
def train(hparams, scope=None, target_session=""):
    """Train a translation model.

    Builds train/eval/infer models on one shared graph and session, then
    runs the training loop with periodic stats logging, checkpointing, and
    internal/external evaluation.

    Args:
      hparams: Hyperparameters object; read for the training schedule, data
        paths and architecture, and mutated in place (epoch_step).
      scope: Optional variable scope passed through to the model builders.
      target_session: Session target string (e.g. for distributed training).

    Returns:
      A (final_eval_metrics, global_step) tuple from the final full eval.

    Raises:
      ValueError: if the attention architecture is not recognized.
    """
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    # Internal eval runs 10x less often than stats logging.
    steps_per_eval = 10 * steps_per_stats

    # External eval defaults to 5x the internal-eval interval when unset.
    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    # Pick the model class from the attention configuration.
    if not hparams.attention:
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt"
                or hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    # All three models live in the same graph so they can share one session.
    combined_graph = tf.Graph()
    train_model = model_helper.create_train_model(model_creator,
                                                  hparams,
                                                  scope,
                                                  graph=combined_graph)
    eval_model = model_helper.create_eval_model(model_creator,
                                                hparams,
                                                scope,
                                                graph=combined_graph)
    infer_model = model_helper.create_infer_model(model_creator,
                                                  hparams,
                                                  scope,
                                                  graph=combined_graph)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    dev_ctx_file = None
    if hparams.ctx is not None:
        dev_ctx_file = "%s.%s" % (hparams.dev_prefix, hparams.ctx)

    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)
    sample_ctx_data = None
    if dev_ctx_file is not None:
        sample_ctx_data = inference.load_data(dev_ctx_file)

    # Optional per-example annotations for the dev set.
    sample_annot_data = None
    if hparams.dev_annotations is not None:
        sample_annot_data = inference.load_data(hparams.dev_annotations)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    sess = tf.Session(target=target_session,
                      config=config_proto,
                      graph=combined_graph)

    # Initialize variables/tables, then load a checkpoint from model_dir if
    # one exists (otherwise start fresh at the returned global_step).
    with train_model.graph.as_default():
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    run_full_eval(infer_model, sess, eval_model, sess, hparams, summary_writer,
                  sample_src_data, sample_ctx_data, sample_tgt_data,
                  sample_annot_data)

    # Track when stats / internal eval / external eval last ran.
    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats, info, start_train_time = before_train(loaded_train_model,
                                                 train_model, sess,
                                                 global_step, hparams, log_f)
    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(sess)
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch:
            # run evals, re-initialize the dataset iterator, and continue.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            run_sample_decode(infer_model, sess, hparams, summary_writer,
                              sample_src_data, sample_ctx_data,
                              sample_tgt_data, sample_annot_data)
            run_external_eval(infer_model, sess, hparams, summary_writer)

            # skip_count=0: restart the dataset from the beginning.
            sess.run(train_model.iterator.initializer,
                     feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Process step_result, accumulate stats, and write summary
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = process_stats(stats, info, global_step,
                                        steps_per_stats, log_f)
            print_step_info("  ", global_step, info,
                            _get_best_results(hparams), log_f)
            # Stop training on numeric overflow (e.g. diverged loss).
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              info["train_ppl"])

            # Save checkpoint
            loaded_train_model.saver.save(sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, sess, hparams, summary_writer,
                              sample_src_data, sample_ctx_data,
                              sample_tgt_data, sample_annot_data)

            # NOTE(review): dev_ppl/test_ppl are computed but never consumed
            # here — presumably run_internal_eval logs them via the summary
            # writer; confirm before removing.
            dev_ppl, test_ppl = None, None
            # only evaluate perplexity when supervised learning
            if not hparams.use_rl:
                dev_ppl, test_ppl = run_internal_eval(eval_model, sess,
                                                      hparams, summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, sess, hparams, summary_writer,
                              sample_src_data, sample_ctx_data,
                              sample_tgt_data, sample_annot_data)
            run_external_eval(infer_model, sess, hparams, summary_writer)

    # Done training
    loaded_train_model.saver.save(sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)

    # Final full evaluation on the last checkpoint.
    (result_summary, _, final_eval_metrics) = (run_full_eval(
        infer_model, sess, eval_model, sess, hparams, summary_writer,
        sample_src_data, sample_ctx_data, sample_tgt_data, sample_annot_data))

    print_step_info("# Final, ", global_step, info, result_summary, log_f)
    utils.print_time("# Done training!", start_train_time)

    summary_writer.close()
    return final_eval_metrics, global_step