    def __init__(self, bot):
        logging.basicConfig(level=logging.INFO)
        out_dir = '<path_to_model>'

        self.ping_replace = re.compile(r"<@![0-9]{2,}>", re.IGNORECASE)
        self.bot = bot

        nmt_parser = argparse.ArgumentParser()
        add_arguments(nmt_parser)
        flags, unparsed = nmt_parser.parse_known_args()
        default_hparams = create_hparams(flags)
        self.hparams = create_or_load_hparams(out_dir,
                                              default_hparams,
                                              flags.hparams_path,
                                              save_hparams=False)
        ckpt = tf.train.latest_checkpoint(out_dir)

        if not self.hparams.attention:
            model_creator = nmt_model.Model
        elif self.hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        elif self.hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
            model_creator = gnmt_model.GNMTModel
        else:
            raise ValueError("Unknown model architecture")
        self.infer_model = model_helper.create_infer_model(
            model_creator, self.hparams, None)
        self.session = tf.InteractiveSession(graph=self.infer_model.graph,
                                             config=utils.get_config_proto())
        self.loaded_infer_model = model_helper.load_model(
            self.infer_model.model, ckpt, self.session, "infer")
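A minimal usage sketch for the cog above, following the iterator/decode pattern that Examples #7 and #8 use; the `translate` method name and the `from nmt.utils import nmt_utils` import are assumptions, not part of the original snippet.

    def translate(self, text):
        # Hypothetical helper: feed one lowercased sentence through the
        # infer graph and return the decoded reply as a unicode string.
        # Assumes a module-level `from nmt.utils import nmt_utils`.
        self.session.run(
            self.infer_model.iterator.initializer,
            feed_dict={
                self.infer_model.src_placeholder: [text.lower()],
                self.infer_model.batch_size_placeholder: 1,
            })
        nmt_outputs, _ = self.loaded_infer_model.decode(self.session)
        if self.hparams.beam_width > 0:
            nmt_outputs = nmt_outputs[0]  # keep only the top beam
        translation = nmt_utils.get_translation(
            nmt_outputs,
            sent_id=0,
            tgt_eos=self.hparams.eos,
            subword_option=self.hparams.subword_option)
        return translation.decode("utf-8")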
Example #2
    def nmt_main(self, flags, default_hparams, scope=None):
        out_dir = flags.out_dir
        if not tf.gfile.Exists(out_dir):
            tf.gfile.MakeDirs(out_dir)

        self.hparams = nmt.create_or_load_hparams(out_dir,
                                                  default_hparams,
                                                  flags.hparams_path,
                                                  save_hparams=False)

        self.ckpt = tf.train.latest_checkpoint(out_dir)
        if not self.ckpt:
            print("No checkpoint found in %s; training is required." % out_dir)
            sys.exit()

        hparams = self.hparams
        if not hparams.attention:
            model_creator = model.Model
        elif hparams.attention_architecture == "standard":
            model_creator = attention_model.AttentionModel
        elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
            model_creator = gnmt_model.GNMTModel
        else:
            raise ValueError("Unknown model architecture")
        self.infer_model = model_helper.create_infer_model(
            model_creator, hparams, scope)

        self.sess = tf.Session(graph=self.infer_model.graph,
                               config=utils.get_config_proto())

        with self.infer_model.graph.as_default():
            self.loaded_infer_model = model_helper.load_model(
                self.infer_model.model, self.ckpt, self.sess, 'infer')
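A driver sketch for `nmt_main`, mirroring the flag handling in Examples #1 and #5; `TranslatorBot` is a hypothetical name for the class that defines the method above.

import argparse

from nmt.nmt import add_arguments, create_hparams

nmt_parser = argparse.ArgumentParser()
add_arguments(nmt_parser)
flags, unparsed = nmt_parser.parse_known_args()  # e.g. --out_dir=/tmp/nmt_model

bot = TranslatorBot()  # hypothetical enclosing class
bot.nmt_main(flags, create_hparams(flags))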
Example #3
def predicate(ckpt,
              hparams,
              num_workers=1,
              jobid=0,
              scope=None):
    """Perform translation."""
    if hparams.inference_indices:
        assert num_workers == 1

    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
    else:
        raise ValueError("Unknown model architecture")
    infer_model = model_helper.create_infer_model(model_creator, hparams, scope)

    with tf.Session(graph=infer_model.graph,
                    config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_helper.load_model(
            infer_model.model, ckpt, sess, "infer")
        while True:
            input_data = input("translate>")
            res = translate_and_return(hparams,
                                       infer_model,
                                       [input_data.lower()],
                                       loaded_infer_model,
                                       sess)
            print("result: %s" % res.decode("utf-8"))
Example #4
    def __init__(self):
        hparams = load_hparams('/tmp/nmt_model')
        ckpt = tf.train.latest_checkpoint('/tmp/nmt_model')
        self.model = create_infer_model(Model, hparams)
        self.sess = tf.Session(graph=self.model.graph,
                               config=get_config_proto())
        with self.model.graph.as_default():
            self.loaded_infer_model = load_model(self.model.model, ckpt,
                                                 self.sess, "infer")
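This snippet uses `load_hparams`, `create_infer_model`, `load_model`, `Model`, and `get_config_proto` unqualified; a sketch of the imports it presumably relies on, following the tensorflow/nmt package layout:

import tensorflow as tf

from nmt.model import Model
from nmt.model_helper import create_infer_model, load_model
from nmt.utils.misc_utils import get_config_proto, load_hparams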
Example #5
def chpt_to_dict_arrays():
    """
    Convert a checkpoint into a dictionary of numpy arrays
    for later use in TensorRT NMT sample.
    git clone https://github.com/tensorflow/nmt.git
    """
    sys.path.append("./nmt")
    from nmt.nmt import add_arguments, create_hparams
    from nmt import attention_model
    from nmt import model_helper
    from nmt.nmt import create_or_load_hparams
    from nmt import utils
    from nmt import model as nmt_model

    nmt_parser = argparse.ArgumentParser()
    add_arguments(nmt_parser)
    FLAGS, unparsed = nmt_parser.parse_known_args()

    default_hparams = create_hparams(FLAGS)

    hparams = create_or_load_hparams(FLAGS.out_dir,
                                     default_hparams,
                                     FLAGS.hparams_path,
                                     save_hparams=False)

    print(hparams)

    model_creator = None
    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    else:
        raise ValueError("Unknown model architecture")

    infer_model = model_helper.create_infer_model(model_creator,
                                                  hparams,
                                                  scope=None)

    params = {}
    print("\nFound the following trainable variables:")
    with tf.Session(graph=infer_model.graph,
                    config=utils.misc_utils.get_config_proto()) as sess:

        loaded_infer_model = model_helper.load_model(infer_model.model,
                                                     FLAGS.ckpt, sess, "infer")

        variables = tf.trainable_variables()
        for v in variables:
            params[v.name] = v.eval(session=sess)
            print("{0}    {1}".format(v.name, params[v.name].shape))

    params["forget_bias"] = hparams.forget_bias
    return params
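A short follow-on sketch: the returned dictionary maps raw TF variable names to numpy arrays (plus the scalar "forget_bias"), so pickling it is a simple way to hand the weights to a separate TensorRT process; the file name is illustrative.

import pickle

params = chpt_to_dict_arrays()
with open("nmt_weights.pkl", "wb") as f:
    pickle.dump(params, f)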
Example #6
def single_worker_inference(infer_model,
                            ckpt,
                            inference_input_file,
                            inference_output_file,
                            hparams):
    """Inference with a single worker."""
    output_infer = inference_output_file

    # Read data. Note: this interactive variant reads source sentences from
    # stdin below, so the data loaded from inference_input_file goes unused.
    infer_data = load_data(inference_input_file, hparams)

    with tf.Session(
            graph=infer_model.graph, config=utils.get_config_proto()) as sess:
        # Load the checkpoint once, outside the prompt loop.
        loaded_infer_model = model_helper.load_model(
            infer_model.model, ckpt, sess, "infer")
        while True:
            # e.g. "Lúc đấy tôi nghĩ chuyện này sẽ khó khăn gian khổ đây ."
            var = input("Input Vi Src: ")
            infer_data = [var]
            sess.run(
                infer_model.iterator.initializer,
                feed_dict={
                    infer_model.src_placeholder: infer_data,
                    infer_model.batch_size_placeholder: hparams.infer_batch_size
                })
            # Decode
            utils.print_out("# Start decoding")
            if hparams.inference_indices:
                _decode_inference_indices(
                    loaded_infer_model,
                    sess,
                    output_infer=output_infer,
                    output_infer_summary_prefix=output_infer,
                    inference_indices=hparams.inference_indices,
                    tgt_eos=hparams.eos,
                    subword_option=hparams.subword_option)
            else:
                nmt_utils.decode_and_evaluate(
                    "infer",
                    loaded_infer_model,
                    sess,
                    output_infer,
                    ref_file=None,
                    metrics=hparams.metrics,
                    subword_option=hparams.subword_option,
                    beam_width=hparams.beam_width,
                    tgt_eos=hparams.eos,
                    num_translations_per_input=hparams.num_translations_per_input)
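A driver sketch for this interactive variant, assuming the model directory already contains saved hparams (loaded here via `nmt.utils.misc_utils.load_hparams`, which returns None if no hparams file exists) and building the infer model the same way as the other examples:

import tensorflow as tf

from nmt import attention_model, model_helper
from nmt.utils import misc_utils

out_dir = "/tmp/nmt_model"                  # illustrative path
hparams = misc_utils.load_hparams(out_dir)
ckpt = tf.train.latest_checkpoint(out_dir)
infer_model = model_helper.create_infer_model(
    attention_model.AttentionModel, hparams, scope=None)
single_worker_inference(infer_model, ckpt,
                        "infer_input.vi",   # unused: the loop reads stdin
                        "infer_output.en", hparams)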
Example #7
def translate(ckpt,
              infer_data,
              inference_output_file,
              hparams,
              num_workers=1,
              jobid=0,
              scope=None):
    """Inference with a single worker."""
    output_infer = inference_output_file

    """Perform translation."""
    if hparams.inference_indices:
        assert num_workers == 1

    if not hparams.attention:
        model_creator = nmt_model.Model
    elif hparams.attention_architecture == "standard":
        model_creator = attention_model.AttentionModel
    elif hparams.attention_architecture in ["gnmt", "gnmt_v2"]:
        model_creator = gnmt_model.GNMTModel
    else:
        raise ValueError("Unknown model architecture")
    infer_model = model_helper.create_infer_model(model_creator, hparams, scope)

    with tf.Session(
            graph=infer_model.graph, config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_helper.load_model(
            infer_model.model, ckpt, sess, "infer")
        # Encode Data
        sess.run(
            infer_model.iterator.initializer,
            feed_dict={
                infer_model.src_placeholder: infer_data,
                infer_model.batch_size_placeholder: hparams.infer_batch_size
            })
        # Decode
        utils.print_out("# Start decoding")
        return nmt_utils.decode_and_return(
            "infer",
            loaded_infer_model,
            sess,
            output_infer,
            ref_file=None,
            metrics=hparams.metrics,
            subword_option=hparams.subword_option,
            beam_width=hparams.beam_width,
            tgt_eos=hparams.eos,
            num_translations_per_input=hparams.num_translations_per_input)
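`nmt_utils.decode_and_return` is this example's custom variant; the stock nmt repo only provides `nmt_utils.decode_and_evaluate`, which writes and scores an output file. A plausible reconstruction that decodes one batch and returns the first translation instead:

def decode_and_return(name, model, sess, output_infer, ref_file,
                      metrics, subword_option, beam_width, tgt_eos,
                      num_translations_per_input):
    # Hypothetical reconstruction matching the call site above; most
    # parameters are accepted only for signature compatibility.
    nmt_outputs, _ = model.decode(sess)
    if beam_width > 0:
        nmt_outputs = nmt_outputs[0]  # keep only the top beam
    return nmt_utils.get_translation(
        nmt_outputs,
        sent_id=0,
        tgt_eos=tgt_eos,
        subword_option=subword_option)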
Example #8
def generate_reply(input_text, flags):
    # Format data
    tokenized_text = tokenize_text(input_text)
    infer_data = [tokenized_text]
    # Load hparams.
    jobid = flags.jobid
    default_hparams = nmt.create_hparams(flags)
    hparams = nmt.create_or_load_hparams(model_dir,
                                         default_hparams,
                                         flags.hparams_path,
                                         save_hparams=(jobid == 0))
    # Load checkpoint
    ckpt = tf.train.latest_checkpoint(model_dir)
    # Inference
    model_creator = attention_model.AttentionModel
    # Create model
    scope = None
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    with tf.Session(graph=infer_model.graph,
                    config=misc_utils.get_config_proto()) as sess:
        model = model_helper.load_model(infer_model.model, ckpt, sess, "infer")
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: infer_data,
                     infer_model.batch_size_placeholder:
                     hparams.infer_batch_size
                 })
        # Decode
        nmt_outputs, infer_summary = model.decode(sess)
        # get text translation(reply as a chatbot)
        assert nmt_outputs.shape[0] == 1
        translation = nmt_utils.get_translation(
            nmt_outputs,
            sent_id=0,
            tgt_eos=hparams.eos,
            subword_option=hparams.subword_option)

    return translation.decode("utf-8")
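A usage sketch for `generate_reply`; it assumes the module-level `model_dir` and the `tokenize_text` helper that the snippet references, and builds flags with the repo's own argument parser as in Example #1.

import argparse

from nmt.nmt import add_arguments

parser = argparse.ArgumentParser()
add_arguments(parser)
flags, unparsed = parser.parse_known_args()

print(generate_reply("hello there", flags))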
Example #9
def multi_worker_inference(infer_model,
                           ckpt,
                           inference_input_file,
                           inference_output_file,
                           hparams,
                           num_workers,
                           jobid):
    """Inference using multiple workers."""
    assert num_workers > 1

    final_output_infer = inference_output_file
    output_infer = "%s_%d" % (inference_output_file, jobid)
    output_infer_done = "%s_done_%d" % (inference_output_file, jobid)

    # Read data
    infer_data = load_data(inference_input_file, hparams)

    # Split data to multiple workers
    total_load = len(infer_data)
    load_per_worker = int((total_load - 1) / num_workers) + 1
    start_position = jobid * load_per_worker
    end_position = min(start_position + load_per_worker, total_load)
    infer_data = infer_data[start_position:end_position]

    with tf.Session(
            graph=infer_model.graph, config=utils.get_config_proto()) as sess:
        loaded_infer_model = model_helper.load_model(
            infer_model.model, ckpt, sess, "infer")
        sess.run(infer_model.iterator.initializer,
                 {
                     infer_model.src_placeholder: infer_data,
                     infer_model.batch_size_placeholder: hparams.infer_batch_size
                 })
        # Decode
        utils.print_out("# Start decoding")
        nmt_utils.decode_and_evaluate(
            "infer",
            loaded_infer_model,
            sess,
            output_infer,
            ref_file=None,
            metrics=hparams.metrics,
            subword_option=hparams.subword_option,
            beam_width=hparams.beam_width,
            tgt_eos=hparams.eos,
            num_translations_per_input=hparams.num_translations_per_input)

        # Change file name to indicate the file writing is completed.
        tf.gfile.Rename(output_infer, output_infer_done, overwrite=True)

        # Job 0 is responsible for the clean up.
        if jobid != 0: return

        # Now write all translations
        with codecs.getwriter("utf-8")(
                tf.gfile.GFile(final_output_infer, mode="wb")) as final_f:
            for worker_id in range(num_workers):
                worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id)
                while not tf.gfile.Exists(worker_infer_done):
                    utils.print_out("  waitting job %d to complete." % worker_id)
                    time.sleep(10)

                with codecs.getreader("utf-8")(
                        tf.gfile.GFile(worker_infer_done, mode="rb")) as f:
                    for translation in f:
                        final_f.write("%s" % translation)

            for worker_id in range(num_workers):
                worker_infer_done = "%s_done_%d" % (inference_output_file, worker_id)
                tf.gfile.Remove(worker_infer_done)
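A quick, self-contained check of the sharding arithmetic above: 10 sentences over 3 workers yield the slices [0:4], [4:8], [8:10], so every sentence is translated exactly once.

total_load = 10
num_workers = 3
load_per_worker = int((total_load - 1) / num_workers) + 1  # ceiling division
for jobid in range(num_workers):
    start_position = jobid * load_per_worker
    end_position = min(start_position + load_per_worker, total_load)
    print("worker %d -> [%d:%d]" % (jobid, start_position, end_position))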