Example #1
    def nmt_main(self, flags, default_hparams, scope=None):
        out_dir = flags.out_dir
        if not tf.gfile.Exists(out_dir):
            tf.gfile.MakeDirs(out_dir)

        self.hparams = nmt.create_or_load_hparams(out_dir,
                                                  default_hparams,
                                                  flags.hparams_path,
                                                  save_hparams=False)

        self.ckpt = tf.train.latest_checkpoint(out_dir)
        if not self.ckpt:
            print('No checkpoint found in %s; train the model first.' % out_dir)
            sys.exit()

        hparams = self.hparams
        if not hparams.attention:
            model_creator = model.Model
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
            model_creator = gnmt_model.GNMTModel
        else:
            raise ValueError("Unknown model architecture")
        self.infer_model = model_helper.create_infer_model(
            model_creator, hparams, scope)

        self.sess = tf.Session(graph=self.infer_model.graph,
                               config=utils.get_config_proto())

        with self.infer_model.graph.as_default():
            self.loaded_infer_model = model_helper.load_model(
                self.infer_model.model, self.ckpt, self.sess, 'infer')
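The model-selection ladder above (no attention gives Model, "standard" gives AttentionModel, "gnmt"/"gnmt_v2" gives GNMTModel) recurs in most of the examples below; Examples #7 and #11 call it through a get_model_creator helper. A consolidated sketch of such a helper, assuming the nmt_model module alias used by the later examples:

def get_model_creator(hparams):
    # Sketch: consolidates the model-selection ladder repeated in
    # these examples.
    if not hparams.attention:
        return nmt_model.Model
    if hparams.attention_architecture == "standard":
        return attention_model.AttentionModel
    if hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        return gnmt_model.GNMTModel
    raise ValueError("Unknown model architecture")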
Example #2
def predicate(ckpt,
              hparams,
              num_workers=1,
              jobid=0,
              scope=None):
    """Perform translation."""
    if hparams.inference_indices:
        assert num_workers == 1

    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
    else:
        raise ValueError("Unknown model architecture")
    infer_model = model_helper.create_infer_model(model_creator, hparams, scope)

    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_helper.load_model(
            infer_model.model, ckpt, sess, "infer")
        while True:
            input_data = input("translate>")
            res = translate_and_return(hparams,
                                       infer_model,
                                       [input_data.lower()],
                                       loaded_infer_model,
                                       sess)
            print("result: %s" % res.decode("utf-8"))
Example #3
    def __init__(self, bot):
        logging.basicConfig(level=logging.INFO)
        out_dir = '<path_to_model>'

        self.ping_replace = re.compile(r"<@![0-9]{2,}>", re.IGNORECASE)
        self.bot = bot

        nmt_parser = argparse.ArgumentParser()
        add_arguments(nmt_parser)
        flags, unparsed = nmt_parser.parse_known_args()
        default_hparams = create_hparams(flags)
        self.hparams = create_or_load_hparams(out_dir,
                                              default_hparams,
                                              flags.hparams_path,
                                              save_hparams=False)
        ckpt = tf.train.latest_checkpoint(out_dir)

        if not self.hparams.attention:
            model_creator = nmt_model.Model
        elif self.hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        elif self.hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
            model_creator = gnmt_model.GNMTModel
        else:
            raise ValueError("Unknown model architecture")
        self.infer_model = model_helper.create_infer_model(
            model_creator, self.hparams, None)
        self.session = tf.InteractiveSession(graph=self.infer_model.graph,
                                             config=utils.get_config_proto())
        self.loaded_infer_model = model_helper.load_model(
            self.infer_model.model, ckpt, self.session, "infer")
Example #4
    def __init__(self):
        hparams = load_hparams('/tmp/nmt_model')
        ckpt = tf.train.latest_checkpoint('/tmp/nmt_model')
        self.model = create_infer_model(Model, hparams)
        self.sess = tf.Session(graph=self.model.graph,
                               config=get_config_proto())
        with self.model.graph.as_default():
            self.loaded_infer_model = load_model(self.model.model, ckpt,
                                                 self.sess, "infer")
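Example #4 only loads the model. A hedged sketch of a companion method that would run one translation with it, following the placeholder-feeding pattern of Examples #5, #6, and #8; it assumes hparams is also stored on self and that get_translation is imported bare like the other helpers:

    def translate(self, sentence):
        # Hypothetical companion to __init__ above.
        self.sess.run(
            self.model.iterator.initializer,
            feed_dict={
                self.model.src_placeholder: [sentence],
                self.model.batch_size_placeholder: 1
            })
        nmt_outputs, _ = self.loaded_infer_model.decode(self.sess)
        return get_translation(nmt_outputs,
                               sent_id=0,
                               tgt_eos=self.hparams.eos,
                               subword_option=self.hparams.subword_option)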
Example #5
def single_worker_inference(infer_model,
                            ckpt,
                            inference_input_file,
                            inference_output_file,
                            hparams):
    """Inference with a single worker."""
    output_infer = inference_output_file

    # Read data (immediately replaced below by interactive input)
    infer_data = load_data(inference_input_file, hparams)

    with tf.Session(
            graph=infer_model.graph, config=utils.get_config_proto()) as sess:
        # Load the checkpoint once, rather than on every prompt.
        loaded_infer_model = model_helper.load_model(
            infer_model.model, ckpt, sess, "infer")
        while True:
            var = input("Input Vi Src: ")
            infer_data = [var]
            sess.run(
                infer_model.iterator.initializer,
                feed_dict={
                    infer_model.src_placeholder: infer_data,
                    infer_model.batch_size_placeholder: hparams.infer_batch_size
                })
            # Decode
            utils.print_out("# Start decoding")
            if hparams.inference_indices:
                _decode_inference_indices(
                    loaded_infer_model,
                    sess,
                    output_infer=output_infer,
                    output_infer_summary_prefix=output_infer,
                    inference_indices=hparams.inference_indices,
                    tgt_eos=hparams.eos,
                    subword_option=hparams.subword_option)
            else:
                nmt_utils.decode_and_evaluate(
                    "infer",
                    loaded_infer_model,
                    sess,
                    output_infer,
                    ref_file=None,
                    metrics=hparams.metrics,
                    subword_option=hparams.subword_option,
                    beam_width=hparams.beam_width,
                    tgt_eos=hparams.eos,
                    num_translations_per_input=hparams.num_translations_per_input)
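decode_and_evaluate writes the translations to output_infer instead of returning them; to echo each result back in the interactive loop above, a fragment like the following could run after the decode step (a sketch; it assumes codecs is imported as in Example #9 and that the output file is rewritten on every call):

            # Sketch: print what was just written for the current input.
            with codecs.getreader("utf-8")(
                    tf.gfile.GFile(output_infer, mode="rb")) as f:
                print("result: %s" % f.read().strip())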
Example #6
def translate(ckpt,
              infer_data,
              inference_output_file,
              hparams,
              num_workers=1,
              jobid=0,
              scope=None):
    """Inference with a single worker."""
    output_infer = inference_output_file

    """Perform translation."""
    if hparams.inference_indices:
        assert num_workers == 1

    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
    else:
        raise ValueError("Unknown model architecture")
    infer_model = model_helper.create_infer_model(model_creator, hparams, scope)

    with tf.Session(
            graph=infer_model.graph, config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_helper.load_model(
            infer_model.model, ckpt, sess, "infer")
        # Encode Data
        sess.run(
            infer_model.iterator.initializer,
            feed_dict={
                infer_model.src_placeholder: infer_data,
                infer_model.batch_size_placeholder: hparams.infer_batch_size
            })
        # Decode
        utils.print_out("# Start decoding")
        return nmt_utils.decode_and_return(
            "infer",
            loaded_infer_model,
            sess,
            output_infer,
            ref_file=None,
            metrics=hparams.metrics,
            subword_option=hparams.subword_option,
            beam_width=hparams.beam_width,
            tgt_eos=hparams.eos,
            num_translations_per_input=hparams.num_translations_per_input)
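A hypothetical call, assuming a trained model under /tmp/nmt_model and that decode_and_return hands back bytes, as the result handling in Example #2 suggests:

# Hypothetical usage; the path and sentence are placeholders.
ckpt = tf.train.latest_checkpoint("/tmp/nmt_model")
res = translate(ckpt,
                ["how are you ?"],
                "/tmp/nmt_model/output_infer",
                hparams)
print(res.decode("utf-8"))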
Example #7
    def setup(self, flags):
        # Model output directory
        out_dir = flags.out_dir
        if out_dir and not tf.gfile.Exists(out_dir):
            tf.gfile.MakeDirs(out_dir)

        # Load hparams.
        default_hparams = create_hparams(flags)
        loaded_hparams = False
        if flags.ckpt:  # Try to load hparams from the same directory as ckpt
            ckpt_dir = os.path.dirname(flags.ckpt)
            ckpt_hparams_file = os.path.join(ckpt_dir, "hparams")
            if tf.gfile.Exists(ckpt_hparams_file) or flags.hparams_path:
                # Note: for some reason this will create an empty "best_bleu" directory and copy vocab files
                hparams = create_or_load_hparams(ckpt_dir,
                                                 default_hparams,
                                                 flags.hparams_path,
                                                 save_hparams=False)
                loaded_hparams = True

        assert loaded_hparams

        # GPU device
        config_proto = utils.get_config_proto(
            allow_soft_placement=True,
            num_intra_threads=hparams.num_intra_threads,
            num_inter_threads=hparams.num_inter_threads)
        utils.print_out("# Devices visible to TensorFlow: %s" %
                        repr(tf.Session(config=config_proto).list_devices()))

        # Inference indices (inference_indices is broken, but without setting it to None we'll crash)
        hparams.inference_indices = None

        # Create the graph
        model_creator = get_model_creator(hparams)
        infer_model = model_helper.create_infer_model(model_creator,
                                                      hparams,
                                                      scope=None)
        sess, loaded_infer_model = start_sess_and_load_model(
            infer_model, flags.ckpt, hparams)

        # Parameters needed by TF GNMT
        self.hparams = hparams

        self.infer_model = infer_model
        self.sess = sess
        self.loaded_infer_model = loaded_infer_model
Example #8
def generate_reply(input_text, flags):
    # Format data
    tokenized_text = tokenize_text(input_text)
    infer_data = [tokenized_text]
    # Load hparams.
    jobid = flags.jobid
    default_hparams = nmt.create_hparams(flags)
    hparams = nmt.create_or_load_hparams(model_dir,
                                         default_hparams,
                                         flags.hparams_path,
                                         save_hparams=(jobid == 0))
    # Load checkpoint
    ckpt = tf.train.latest_checkpoint(model_dir)
    # Inference
    model_creator = attention_model.AttentionModel
    # Create model
    scope = None
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    with tf.Session(graph=infer_model.graph,
                    config=misc_utils.get_config_proto()) as sess:
        model = model_helper.load_model(infer_model.model, ckpt, sess, "infer")
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: infer_data,
                     infer_model.batch_size_placeholder:
                     hparams.infer_batch_size
                 })
        # Decode
        nmt_outputs, infer_summary = model.decode(sess)
        # get text translation(reply as a chatbot)
        assert nmt_outputs.shape[0] == 1
        translation = nmt_utils.get_translation(
            nmt_outputs,
            sent_id=0,
            tgt_eos=hparams.eos,
            subword_option=hparams.subword_option)

    return translation.decode("utf-8")
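tokenize_text is not shown; a minimal hypothetical version, assuming the model expects lowercased, space-delimited tokens:

import re

def tokenize_text(text):
    # Hypothetical: lowercase and split punctuation off words so the input
    # matches the space-delimited token format the model was trained on.
    text = text.lower().strip()
    text = re.sub(r"([?.!,])", r" \1 ", text)
    return re.sub(r"\s+", " ", text).strip()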
Example #9
def multi_worker_inference(infer_model,
                           ckpt,
                           inference_input_file,
                           inference_output_file,
                           hparams,
                           num_workers,
                           jobid):
    """Inference using multiple workers."""
    assert num_workers > 1

    final_output_infer = inference_output_file
    output_infer = "%s_%d" % (inference_output_file, jobid)
    output_infer_done = "%s_done_%d" % (inference_output_file, jobid)

    # Read data
    infer_data = load_data(inference_input_file, hparams)

    # Split data to multiple workers
    total_load = len(infer_data)
    load_per_worker = int((total_load - 1) / num_workers) + 1
    start_position = jobid * load_per_worker
    end_position = min(start_position + load_per_worker, total_load)
    infer_data = infer_data[start_position:end_position]

    with tf.Session(
            graph=infer_model.graph, config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_helper.load_model(
            infer_model.model, ckpt, sess, "infer")
        sess.run(infer_model.iterator.initializer,
                 {
                     infer_model.src_placeholder: infer_data,
                     infer_model.batch_size_placeholder: hparams.infer_batch_size
                 })
        # Decode
        utils.print_out("# Start decoding")
        nmt_utils.decode_and_evaluate(
            "infer",
            loaded_infer_model,
            sess,
            output_infer,
            ref_file=None,
            metrics=hparams.metrics,
            subword_option=hparams.subword_option,
            beam_width=hparams.beam_width,
            tgt_eos=hparams.eos,
            num_translations_per_input=hparams.num_translations_per_input)

        # Change file name to indicate the file writing is completed.
        tf.gfile.Rename(output_infer, output_infer_done, overwrite=True)

        # Job 0 is responsible for the clean up.
        if jobid != 0: return

        # Now write all translations
        with codecs.getwriter("utf-8")(
                tf.gfile.GFile(final_output_infer, mode="wb")) as final_f:
            for worker_id in range(num_workers):
                worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id)
                while not tf.gfile.Exists(worker_infer_done):
                    utils.print_out("  waitting job %d to complete." % worker_id)
                    time.sleep(10)

                with codecs.getreader("utf-8")(
                        tf.gfile.GFile(worker_infer_done, mode="rb")) as f:
                    for translation in f:
                        final_f.write("%s" % translation)

            for worker_id in range(num_workers):
                worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id)
                tf.gfile.Remove(worker_infer_done)
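The shard split above is ceiling division, so every sentence is assigned and only the last worker can receive a short shard. For example, 10 sentences across 3 workers:

total_load, num_workers = 10, 3
load_per_worker = int((total_load - 1) / num_workers) + 1  # ceil(10/3) == 4
for jobid in range(num_workers):
    start = jobid * load_per_worker
    end = min(start + load_per_worker, total_load)
    print(jobid, start, end)  # -> 0 0 4, then 1 4 8, then 2 8 10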
Example #10
def train(hparams, scope=None, target_session=""):
    """Train a translation model."""
    log_device_placement = hparams.log_device_placement
    out_dir = hparams.out_dir
    num_train_steps = hparams.num_train_steps
    steps_per_stats = hparams.steps_per_stats
    steps_per_external_eval = hparams.steps_per_external_eval
    steps_per_eval = 10 * steps_per_stats
    avg_ckpts = hparams.avg_ckpts

    if not steps_per_external_eval:
        steps_per_external_eval = 5 * steps_per_eval

    if not hparams.attention:
        model_creator = nmt_model.Model
    else:  # Attention
        if (hparams.encoder_type == "gnmt"
                or hparams.attention_architecture in ["gnmt", "gnmt_v2"]):
            model_creator = gnmt_model.GNMTModel
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        else:
            raise ValueError("Unknown attention architecture %s" %
                             hparams.attention_architecture)

    train_model = model_helper.create_train_model(model_creator, hparams,
                                                  scope)
    eval_model = model_helper.create_eval_model(model_creator, hparams, scope)
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    # Preload data for sample decoding.
    dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
    dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
    sample_src_data = inference.load_data(dev_src_file)
    sample_tgt_data = inference.load_data(dev_tgt_file)

    summary_name = "train_log"
    model_dir = hparams.out_dir

    # Log and output files
    log_file = os.path.join(out_dir, "log_%d" % time.time())
    log_f = tf.gfile.GFile(log_file, mode="a")
    utils.print_out("# log_file=%s" % log_file, log_f)

    # TensorFlow model
    config_proto = utils.get_config_proto(
        log_device_placement=log_device_placement,
        num_intra_threads=hparams.num_intra_threads,
        num_inter_threads=hparams.num_inter_threads)
    train_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=train_model.graph)
    eval_sess = tf.Session(target=target_session,
                           config=config_proto,
                           graph=eval_model.graph)
    infer_sess = tf.Session(target=target_session,
                            config=config_proto,
                            graph=infer_model.graph)

    with train_model.graph.as_default():
        loaded_train_model, global_step = model_helper.create_or_load_model(
            train_model.model, model_dir, train_sess, "train")

    # Summary writer
    summary_writer = tf.summary.FileWriter(os.path.join(out_dir, summary_name),
                                           train_model.graph)

    # First evaluation
    run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
                  hparams, summary_writer, sample_src_data, sample_tgt_data,
                  avg_ckpts)

    last_stats_step = global_step
    last_eval_step = global_step
    last_external_eval_step = global_step

    # This is the training loop.
    stats, info, start_train_time = before_train(loaded_train_model,
                                                 train_model, train_sess,
                                                 global_step, hparams, log_f)
    while global_step < num_train_steps:
        ### Run a step ###
        start_time = time.time()
        try:
            step_result = loaded_train_model.train(train_sess)
            hparams.epoch_step += 1
        except tf.errors.OutOfRangeError:
            # Finished going through the training dataset.  Go to next epoch.
            hparams.epoch_step = 0
            utils.print_out(
                "# Finished an epoch, step %d. Perform external evaluation" %
                global_step)
            #run_sample_decode(infer_model, infer_sess, model_dir, hparams,
            #                  summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

            train_sess.run(train_model.iterator.initializer,
                           feed_dict={train_model.skip_count_placeholder: 0})
            continue

        # Process step_result, accumulate stats, and write summary
        global_step, info["learning_rate"], step_summary = update_stats(
            stats, start_time, step_result)
        summary_writer.add_summary(step_summary, global_step)

        # Once in a while, we print statistics.
        if global_step - last_stats_step >= steps_per_stats:
            last_stats_step = global_step
            is_overflow = process_stats(stats, info, global_step,
                                        steps_per_stats, log_f)
            print_step_info("  ", global_step, info,
                            _get_best_results(hparams), log_f)
            if is_overflow:
                break

            # Reset statistics
            stats = init_stats()

        if global_step - last_eval_step >= steps_per_eval:
            last_eval_step = global_step
            utils.print_out("# Save eval, global step %d" % global_step)
            utils.add_summary(summary_writer, global_step, "train_ppl",
                              info["train_ppl"])

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)

            # Evaluate on dev/test
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_internal_eval(eval_model, eval_sess, model_dir, hparams,
                              summary_writer)

        if global_step - last_external_eval_step >= steps_per_external_eval:
            last_external_eval_step = global_step

            # Save checkpoint
            loaded_train_model.saver.save(train_sess,
                                          os.path.join(out_dir,
                                                       "translate.ckpt"),
                                          global_step=global_step)
            run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                              summary_writer, sample_src_data, sample_tgt_data)
            run_external_eval(infer_model, infer_sess, model_dir, hparams,
                              summary_writer)

            if avg_ckpts:
                run_avg_external_eval(infer_model, infer_sess, model_dir,
                                      hparams, summary_writer, global_step)

    # Done training
    loaded_train_model.saver.save(train_sess,
                                  os.path.join(out_dir, "translate.ckpt"),
                                  global_step=global_step)

    (result_summary, _, final_eval_metrics) = (run_full_eval(
        model_dir, infer_model, infer_sess, eval_model, eval_sess, hparams,
        summary_writer, sample_src_data, sample_tgt_data, avg_ckpts))
    print_step_info("# Final, ", global_step, info, result_summary, log_f)
    utils.print_time("# Done training!", start_train_time)

    summary_writer.close()

    utils.print_out("# Start evaluating saved best models.")
    for metric in hparams.metrics:
        best_model_dir = getattr(hparams, "best_" + metric + "_dir")
        summary_writer = tf.summary.FileWriter(
            os.path.join(best_model_dir, summary_name), infer_model.graph)
        result_summary, best_global_step, _ = run_full_eval(
            best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
            hparams, summary_writer, sample_src_data, sample_tgt_data)
        print_step_info("# Best %s, " % metric, best_global_step, info,
                        result_summary, log_f)
        summary_writer.close()

        if avg_ckpts:
            best_model_dir = getattr(hparams, "avg_best_" + metric + "_dir")
            summary_writer = tf.summary.FileWriter(
                os.path.join(best_model_dir, summary_name), infer_model.graph)
            result_summary, best_global_step, _ = run_full_eval(
                best_model_dir, infer_model, infer_sess, eval_model, eval_sess,
                hparams, summary_writer, sample_src_data, sample_tgt_data)
            print_step_info("# Averaged Best %s, " % metric, best_global_step,
                            info, result_summary, log_f)
            summary_writer.close()

    return final_eval_metrics, global_step
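Note how the evaluation cadence is derived from steps_per_stats: internal evaluation and checkpointing run every 10 * steps_per_stats steps, and external evaluation defaults to five times that when steps_per_external_eval is unset. For instance:

steps_per_stats = 100                         # example value
steps_per_eval = 10 * steps_per_stats         # eval/checkpoint every 1,000 steps
steps_per_external_eval = 5 * steps_per_eval  # external eval every 5,000 steps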
Example #11
    def setup(self, flags):
        # Model output directory
        out_dir = flags.out_dir
        if out_dir and not tf.gfile.Exists(out_dir):
            tf.gfile.MakeDirs(out_dir)

        # Load hparams.
        default_hparams = create_hparams(flags)
        loaded_hparams = False
        if flags.ckpt:  # Try to load hparams from the same directory as ckpt
            ckpt_dir = os.path.dirname(flags.ckpt)
            ckpt_hparams_file = os.path.join(ckpt_dir, "hparams")
            if tf.gfile.Exists(ckpt_hparams_file) or flags.hparams_path:
                # Note: for some reason this will create an empty "best_bleu" directory and copy vocab files
                hparams = create_or_load_hparams(ckpt_dir,
                                                 default_hparams,
                                                 flags.hparams_path,
                                                 save_hparams=False)
                loaded_hparams = True

        assert loaded_hparams

        # GPU device
        config_proto = utils.get_config_proto(
            allow_soft_placement=True,
            num_intra_threads=hparams.num_intra_threads,
            num_inter_threads=hparams.num_inter_threads)
        utils.print_out(
            "# Devices visible to TensorFlow: %s"
            % repr(tf.Session(config=config_proto).list_devices()))

        # Inference indices (inference_indices is broken, but without setting it to None we'll crash)
        hparams.inference_indices = None

        # Create the graph
        model_creator = get_model_creator(hparams)
        infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                      scope=None)
        sess, loaded_infer_model = start_sess_and_load_model(
            infer_model, flags.ckpt, hparams)

        # FIXME (bryce): Set to False to disable inference from frozen graph and run fast again
        if True:
            frozen_graph = None
            with infer_model.graph.as_default():
                output_node_names = ['hash_table_Lookup_1/LookupTableFindV2']
                other_node_names = ['MakeIterator', 'IteratorToStringHandle',
                                    'init_all_tables', 'NoOp',
                                    'dynamic_seq2seq/decoder/NoOp']
                frozen_graph = tf.graph_util.convert_variables_to_constants(
                    sess,
                    tf.get_default_graph().as_graph_def(),
                    output_node_names=output_node_names + other_node_names)

                # FIXME (bryce): Uncomment this block to enable tensorRT convert
                from tensorflow.python.compiler.tensorrt import trt_convert as trt
                converter = trt.TrtGraphConverter(
                    input_graph_def=frozen_graph,
                    nodes_blacklist=output_node_names,
                    is_dynamic_op=True,
                    max_batch_size=hparams.infer_batch_size,
                    max_beam_size=hparams.beam_width,
                    max_src_seq_len=hparams.src_max_len)
                frozen_graph = converter.convert()

            with tf.Graph().as_default() as imported_graph:
                tf.import_graph_def(frozen_graph, name="")
                sess = tf.Session(
                    graph=imported_graph,
                    config=utils.get_config_proto(
                        num_intra_threads=hparams.num_intra_threads,
                        num_inter_threads=hparams.num_inter_threads))
                # Re-bind the iterator and placeholders to the imported graph,
                # looking them up by the names they had in the original graph.
                iterator = iterator_utils.BatchedInput(
                    initializer=imported_graph.get_operation_by_name(
                        infer_model.iterator.initializer.name),
                    source=imported_graph.get_tensor_by_name(
                        infer_model.iterator.source.name),
                    target_input=None,
                    target_output=None,
                    source_sequence_length=imported_graph.get_tensor_by_name(
                        infer_model.iterator.source_sequence_length.name),
                    target_sequence_length=None)
                infer_model = model_helper.InferModel(
                    graph=imported_graph,
                    model=infer_model.model,
                    src_placeholder=imported_graph.get_tensor_by_name(
                        infer_model.src_placeholder.name),
                    batch_size_placeholder=imported_graph.get_tensor_by_name(
                        infer_model.batch_size_placeholder.name),
                    iterator=iterator)

        # Parameters needed by TF GNMT
        self.hparams = hparams

        self.infer_model = infer_model
        self.sess = sess
        self.loaded_infer_model = loaded_infer_model
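The graph surgery in Example #11 relies on import_graph_def(..., name="") preserving node names, so tensors from the original graph can be looked up by name in the imported one. A minimal standalone illustration with a hypothetical toy graph:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    x = tf.placeholder(tf.float32, shape=[], name="x")
    y = tf.add(x, 1.0, name="y")
    graph_def = g.as_graph_def()

with tf.Graph().as_default() as g2:
    tf.import_graph_def(graph_def, name="")  # name="" keeps original names
    x2 = g2.get_tensor_by_name("x:0")
    y2 = g2.get_tensor_by_name("y:0")
    with tf.Session(graph=g2) as sess:
        print(sess.run(y2, feed_dict={x2: 41.0}))  # prints 42.0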