Example #1
def single_worker_inference(infer_model, ckpt, inference_input_file,
                            inference_context_file, inference_output_file,
                            hparams):
    """Inference with a single worker."""
    output_infer = inference_output_file

    # Read data
    infer_data = load_data(inference_input_file, hparams)
    if inference_context_file is not None:
        infer_context = load_data(inference_context_file, hparams)
    else:
        infer_context = None

    infer_feed_dict = {
        infer_model.src_placeholder: infer_data,
        infer_model.batch_size_placeholder: hparams.infer_batch_size
    }
    if infer_context is not None:
        infer_feed_dict[infer_model.ctx_placeholder] = infer_context

    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_helper.load_model(infer_model.model, ckpt,
                                                     sess, "infer")
        sess.run(infer_model.iterator.initializer, feed_dict=infer_feed_dict)
        # Decode
        utils.print_out("# Start decoding")
        if hparams.inference_indices:
            _decode_inference_indices(
                loaded_infer_model,
                sess,
                output_infer=output_infer,
                output_infer_summary_prefix=output_infer,
                inference_indices=hparams.inference_indices,
                tgt_eos=hparams.eos,
                subword_option=hparams.subword_option)
        else:
            nmt_utils.decode_and_evaluate(
                "infer",
                loaded_infer_model,
                sess,
                output_infer,
                ref_file=None,
                metrics=hparams.metrics,
                subword_option=hparams.subword_option,
                beam_width=hparams.beam_width,
                tgt_eos=hparams.eos,
                hparams=hparams,
                num_translations_per_input=hparams.num_translations_per_input)
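Note: single_worker_inference assumes the infer model exposes src_placeholder and batch_size_placeholder feeding an initializable tf.data iterator. Below is a minimal, self-contained sketch of that pattern (TF1 API, toy data; not part of the example itself):

import tensorflow as tf

src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

# Build a dataset whose contents and batch size are supplied at run time.
dataset = tf.data.Dataset.from_tensor_slices(src_placeholder)
dataset = dataset.batch(batch_size_placeholder)
iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer,
             feed_dict={src_placeholder: ["a b c", "d e"],
                        batch_size_placeholder: 2})
    print(sess.run(next_batch))  # [b'a b c', b'd e']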
Example #2
    def __init__(self, hparams_path, source_prefix, out_dir,
                 environment_server_address):
        """Constructor for the reformulator.

    Args:
      hparams_path: Path to json hparams file.
      source_prefix: A prefix that is added to every question before
        translation which should be used for adding tags like <en> <2en>.
        Can be empty or None in which case the prefix is ignored.
      out_dir: Directory where the model output will be written.
      environment_server_address: Address of the environment server.

    Raises:
      ValueError: if model architecture is not known.
    """

        self.hparams = load_hparams(hparams_path, out_dir)
        assert self.hparams.num_buckets == 1, "No bucketing when in server mode."
        assert not self.hparams.server_mode, (
            "server_mode set to True but not "
            "running as server.")

        self.hparams.environment_server = environment_server_address
        if self.hparams.subword_option == "spm":
            self.sentpiece = sentencepiece_processor.SentencePieceProcessor()
            self.sentpiece.Load(self.hparams.subword_model.encode("utf-8"))
        self.source_prefix = source_prefix

        # Create the model
        if not self.hparams.attention:
            model_creator = nmt_model.Model
        elif self.hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        elif self.hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
            model_creator = gnmt_model.GNMTModel
        else:
            raise ValueError("Unknown model architecture")

        self.trie = trie_decoder_utils.DecoderTrie(
            vocab_path=self.hparams.tgt_vocab_file,
            eos_token=self.hparams.eos,
            subword_option=self.hparams.subword_option,
            subword_model=self.hparams.get("subword_model"),
            optimize_ngrams_len=self.hparams.optimize_ngrams_len)
        if self.hparams.trie_path is not None and tf.gfile.Exists(
                self.hparams.trie_path):
            self.trie.populate_from_text_file(self.hparams.trie_path)

        combined_graph = tf.Graph()
        self.train_model = model_helper.create_train_model(
            model_creator,
            self.hparams,
            graph=combined_graph,
            trie=self.trie,
            use_placeholders=True)

        # Create separate inference models for greedy, sampling, beam-search,
        # and trie-constrained decoding.
        default_infer_mode = self.hparams.infer_mode
        default_beam_width = self.hparams.beam_width
        self.infer_models = {}
        self.hparams.use_rl = False
        self.hparams.infer_mode = "greedy"
        self.hparams.beam_width = 0
        self.infer_models[reformulator_pb2.ReformulatorRequest.
                          GREEDY] = model_helper.create_infer_model(
                              model_creator,
                              self.hparams,
                              graph=combined_graph)

        self.hparams.infer_mode = "sample"
        self.hparams.beam_width = 0
        self.infer_models[reformulator_pb2.ReformulatorRequest.
                          SAMPLING] = model_helper.create_infer_model(
                              model_creator,
                              self.hparams,
                              graph=combined_graph)

        self.hparams.infer_mode = "beam_search"
        self.hparams.beam_width = max(1, default_beam_width)
        self.infer_models[reformulator_pb2.ReformulatorRequest.
                          BEAM_SEARCH] = model_helper.create_infer_model(
                              model_creator,
                              self.hparams,
                              graph=combined_graph)

        self.hparams.infer_mode = "trie_greedy"
        self.hparams.beam_width = 0
        self.infer_models[reformulator_pb2.ReformulatorRequest.
                          TRIE_GREEDY] = model_helper.create_infer_model(
                              model_creator,
                              self.hparams,
                              graph=combined_graph,
                              trie=self.trie)

        self.hparams.infer_mode = "trie_sample"
        self.hparams.beam_width = 0
        self.infer_models[reformulator_pb2.ReformulatorRequest.
                          TRIE_SAMPLE] = model_helper.create_infer_model(
                              model_creator,
                              self.hparams,
                              graph=combined_graph,
                              trie=self.trie)

        self.hparams.infer_mode = "trie_beam_search"
        self.hparams.beam_width = max(1, default_beam_width)
        self.infer_models[reformulator_pb2.ReformulatorRequest.
                          TRIE_BEAM_SEARCH] = model_helper.create_infer_model(
                              model_creator,
                              self.hparams,
                              graph=combined_graph,
                              trie=self.trie)

        self.hparams.infer_mode = default_infer_mode
        self.hparams.beam_width = default_beam_width

        self.hparams.use_rl = True
        self.sess = tf.Session(graph=combined_graph,
                               config=misc_utils.get_config_proto())

        with combined_graph.as_default():
            self.sess.run(tf.global_variables_initializer())
            self.sess.run(tf.tables_initializer())
            _, global_step = model_helper.create_or_load_model(
                self.train_model.model, out_dir, self.sess, "train")
            self.last_save_step = global_step

        self.summary_writer = tf.summary.FileWriter(
            os.path.join(out_dir, "train_log"), self.train_model.graph)
        self.checkpoint_path = os.path.join(out_dir, "translate.ckpt")
        self.trie_save_path = os.path.join(out_dir, "trie")
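The constructor above builds one inference model per decode mode by temporarily mutating the shared hparams (infer_mode, beam_width) and restoring the defaults once all models are built. A small, generic sketch of that save-override-restore pattern as a context manager (hypothetical helper, not from the repo), which makes the restore harder to forget:

import contextlib

@contextlib.contextmanager
def override_hparams(hparams, **overrides):
    """Temporarily set attributes on an hparams-like object, then restore them."""
    saved = {name: getattr(hparams, name) for name in overrides}
    try:
        for name, value in overrides.items():
            setattr(hparams, name, value)
        yield hparams
    finally:
        for name, value in saved.items():
            setattr(hparams, name, value)

# Hypothetical usage mirroring the constructor above:
# with override_hparams(self.hparams, infer_mode="sample", beam_width=0):
#     model = model_helper.create_infer_model(model_creator, self.hparams,
#                                             graph=combined_graph)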
Example #3
def train(hparams, scope=None, target_session=""):
    """Train a translation model."""
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats

    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if not hparams.attention:
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt"
                or hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    combined_graph = tf.Graph()
    train_model = model_helper.create_train_model(model_creator,
                                                  hparams,
                                                  scope,
                                                  graph=combined_graph)
    eval_model = model_helper.create_eval_model(model_creator,
                                                hparams,
                                                scope,
                                                graph=combined_graph)
    infer_model = model_helper.create_infer_model(model_creator,
                                                  hparams,
                                                  scope,
                                                  graph=combined_graph)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    dev_ctx_file = None
    if hparams.ctx is not None:
        dev_ctx_file = "%s.%s" % (hparams.dev_prefix, hparams.ctx)

    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)
    sample_ctx_data = None
    if dev_ctx_file is not None:
        sample_ctx_data = inference.load_data(dev_ctx_file)

    sample_annot_data = None
    if hparams.dev_annotations is not None:
        sample_annot_data = inference.load_data(hparams.dev_annotations)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    sess = tf.Session(target=target_session,
                      config=config_proto,
                      graph=combined_graph)

    with train_model.graph.as_default():
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    run_full_eval(infer_model, sess, eval_model, sess, hparams, summary_writer,
                  sample_src_data, sample_ctx_data, sample_tgt_data,
                  sample_annot_data)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats, info, start_train_time = before_train(loaded_train_model,
                                                 train_model, sess,
                                                 global_step, hparams, log_f)
    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(sess)
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            run_sample_decode(infer_model, sess, hparams, summary_writer,
                              sample_src_data, sample_ctx_data,
                              sample_tgt_data, sample_annot_data)
            run_external_eval(infer_model, sess, hparams, summary_writer)

            sess.run(train_model.iterator.initializer,
                     feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Process step_result, accumulate stats, and write summary
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = process_stats(stats, info, global_step,
                                        steps_per_stats, log_f)
            print_step_info("  ", global_step, info,
                            _get_best_results(hparams), log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              info["train_ppl"])

            # Save checkpoint
            loaded_train_model.saver.save(sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, sess, hparams, summary_writer,
                              sample_src_data, sample_ctx_data,
                              sample_tgt_data, sample_annot_data)

            dev_ppl, test_ppl = None, None
            # Only evaluate perplexity when training with supervised learning.
            if not hparams.use_rl:
                dev_ppl, test_ppl = run_internal_eval(eval_model, sess,
                                                      hparams, summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, sess, hparams, summary_writer,
                              sample_src_data, sample_ctx_data,
                              sample_tgt_data, sample_annot_data)
            run_external_eval(infer_model, sess, hparams, summary_writer)

    # Done training
    loaded_train_model.saver.save(sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)

    (result_summary, _, final_eval_metrics) = (run_full_eval(
        infer_model, sess, eval_model, sess, hparams, summary_writer,
        sample_src_data, sample_ctx_data, sample_tgt_data, sample_annot_data))

    print_step_info("# Final, ", global_step, info, result_summary, log_f)
    utils.print_time("# Done training!", start_train_time)

    summary_writer.close()
    return final_eval_metrics, global_step
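The training loop above treats tf.errors.OutOfRangeError from the dataset iterator as an epoch boundary: it runs the per-epoch evaluations and then re-initializes the iterator. A minimal, self-contained TF1 sketch of that control flow (toy dataset standing in for the training pipeline):

import tensorflow as tf

dataset = tf.data.Dataset.range(5).batch(2)
iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    epoch, steps = 0, 0
    while epoch < 2:  # stop after two passes over the data
        try:
            sess.run(next_batch)  # stands in for loaded_train_model.train(sess)
            steps += 1
        except tf.errors.OutOfRangeError:
            epoch += 1
            sess.run(iterator.initializer)  # start the next epoch
    print("ran %d steps over %d epochs" % (steps, epoch))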
Example #4
def multi_worker_inference(infer_model, ckpt, inference_input_file,
                           inference_context_file, inference_output_file,
                           hparams, num_workers, jobid):
    """Inference using multiple workers."""
    assert num_workers > 1

    final_output_infer = inference_output_file
    output_infer = "%s_%d" % (inference_output_file, jobid)
    output_infer_done = "%s_done_%d" % (inference_output_file, jobid)

    # Read data
    infer_data = load_data(inference_input_file, hparams)
    if inference_context_file is not None:
        infer_context = load_data(inference_context_file, hparams)
    else:
        infer_context = None
    # Split the data across the workers
    total_load = len(infer_data)
    load_per_worker = int((total_load - 1) / num_workers) + 1
    start_position = jobid * load_per_worker
    end_position = min(start_position + load_per_worker, total_load)
    infer_data = infer_data[start_position:end_position]
    if infer_context is not None:
        infer_context = infer_context[start_position:end_position]

    infer_feed_dict = {
        infer_model.src_placeholder: infer_data,
        infer_model.batch_size_placeholder: hparams.infer_batch_size
    }
    if infer_context is not None:
        infer_feed_dict[infer_model.ctx_placeholder] = infer_context

    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_helper.load_model(infer_model.model, ckpt,
                                                     sess, "infer")
        sess.run(infer_model.iterator.initializer, feed_dict=infer_feed_dict)
        # Decode
        utils.print_out("# Start decoding")
        nmt_utils.decode_and_evaluate(
            "infer",
            loaded_infer_model,
            sess,
            output_infer,
            ref_file=None,
            metrics=hparams.metrics,
            subword_option=hparams.subword_option,
            beam_width=hparams.beam_width,
            tgt_eos=hparams.eos,
            hparams=hparams,
            num_translations_per_input=hparams.num_translations_per_input)

        # Rename the file to signal that writing is complete.
        tf.gfile.Rename(output_infer, output_infer_done, overwrite=True)

        # Job 0 is responsible for the clean up.
        if jobid != 0:
            return

        # Now write all translations
        with codecs.getwriter("utf-8")(tf.gfile.GFile(final_output_infer,
                                                      mode="wb")) as final_f:
            for worker_id in range(num_workers):
                worker_infer_done = "%s_done_%d" % (inference_output_file,
                                                    worker_id)
                while not tf.gfile.Exists(worker_infer_done):
                    utils.print_out("  waitting job %d to complete." %
                                    worker_id)
                    time.sleep(10)

                with codecs.getreader("utf-8")(tf.gfile.GFile(
                        worker_infer_done, mode="rb")) as f:
                    for translation in f:
                        final_f.write("%s" % translation)

            for worker_id in range(num_workers):
                worker_infer_done = "%s_done_%d" % (inference_output_file,
                                                    worker_id)
                tf.gfile.Remove(worker_infer_done)
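For reference, the worker sharding in multi_worker_inference is a ceiling division over the input lines followed by a clamp; a quick standalone check of that arithmetic (hypothetical sizes):

def shard_bounds(total_load, num_workers, jobid):
    # Same arithmetic as multi_worker_inference above.
    load_per_worker = int((total_load - 1) / num_workers) + 1
    start = jobid * load_per_worker
    end = min(start + load_per_worker, total_load)
    return start, end

# 10 inputs split across 3 workers -> (0, 4), (4, 8), (8, 10)
print([shard_bounds(10, 3, jobid) for jobid in range(3)])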