Example 1
    def reply(self, s):
        if len(s) == 0 or s == '\r' or s == '\n':
            return ''

        infer_data = [remove_special_char(s)]

        self.sess.run(self.infer_model.iterator.initializer,
                      feed_dict={
                          self.infer_model.src_placeholder:
                          infer_data,
                          self.infer_model.batch_size_placeholder:
                          self.hparams.infer_batch_size
                      })

        beam_width = self.hparams.beam_width
        num_translations_per_input = 1  # only the top hypothesis is kept

        nmt_outputs, _ = self.loaded_infer_model.decode(self.sess)
        if beam_width == 0:
            nmt_outputs = np.expand_dims(nmt_outputs, 0)

        batch_size = nmt_outputs.shape[1]

        # With a single input sentence this nested loop runs once, so
        # `translation` holds the one final hypothesis.
        for sent_id in range(batch_size):
            for beam_id in range(num_translations_per_input):
                translation = nmt_utils.get_translation(
                    nmt_outputs[beam_id],
                    sent_id,
                    tgt_eos=self.hparams.eos,
                    subword_option=self.hparams.subword_option)

        return translation.decode('utf-8')
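
The `np.expand_dims` call above normalizes decoder output shapes: without beam search the model returns a [batch, time] array, while beam search returns [beam, batch, time]. A minimal numpy sketch of that normalization (the shapes are illustrative assumptions, not outputs of a real model):

import numpy as np

# Greedy decoding yields one hypothesis per sentence: [batch, time].
greedy_out = np.zeros((1, 12), dtype=np.int32)
# Insert a leading beam axis so both modes share [beam, batch, time].
greedy_out = np.expand_dims(greedy_out, 0)
assert greedy_out.shape == (1, 1, 12)

# Beam search (beam_width=4) already carries the beam axis first.
beam_out = np.zeros((4, 1, 12), dtype=np.int32)
# Either way, nmt_outputs[beam_id] selects one hypothesis per sentence
# and nmt_outputs.shape[1] is the batch size.
assert beam_out.shape[1] == 1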
Example 2
    def translate(self, sentence_id_list):
        infer_mode = self.hparams.infer_mode

        # Set input data and batch size
        with self.infer_model.graph.as_default():
            self.sess.run(
                self.infer_model.iterator.initializer,
                feed_dict={
                    self.infer_model.src_placeholder: [self.infer_data[i] for i in sentence_id_list],
                    self.infer_model.batch_size_placeholder: min(self.hparams.infer_batch_size, len(sentence_id_list))
                })

        # Start the translation
        nmt_outputs, _ = self.loaded_infer_model.decode(self.sess)
        if infer_mode != "beam_search":
            nmt_outputs = np.expand_dims(nmt_outputs, 0)

        batch_size = nmt_outputs.shape[1]
        assert batch_size <= self.hparams.infer_batch_size

        # Whether beam search is being used or not, we only want 1 final translation
        assert self.hparams.num_translations_per_input == 1

        translation = []
        for decoded_id in range(batch_size):
            translation.append(nmt_utils.get_translation(
                nmt_outputs[0],
                decoded_id,
                tgt_eos=self.hparams.eos,
                subword_option=self.hparams.subword_option))

        # Track how many translations have been produced so far.
        self.count += len(translation)

        return translation
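
`nmt_utils.get_translation` returns byte strings, which is why Example 1 decodes with UTF-8 before returning. A small sketch of post-processing the list this `translate` method returns (the sample bytes are made up for illustration):

results = [b"hello world", b"how are you ?"]
decoded = [line.decode("utf-8") for line in results]
assert decoded == ["hello world", "how are you ?"]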
Example 3
def _decode_inference_indices(model, sess, output_infer,
                              output_infer_summary_prefix,
                              inference_indices,
                              tgt_eos,
                              subword_option):
    """Decoding only a specific set of sentences."""
    utils.print_out("  decoding to output %s , num sents %d." %
                    (output_infer, len(inference_indices)))
    start_time = time.time()
    with codecs.getwriter("utf-8")(
            tf.gfile.GFile(output_infer, mode="wb")) as trans_f:
        trans_f.write("")  # Write empty string to ensure file is created.
        for decode_id in inference_indices:
            nmt_outputs, infer_summary = model.decode(sess)

            # get text translation
            assert nmt_outputs.shape[0] == 1
            translation = nmt_utils.get_translation(
                nmt_outputs,
                sent_id=0,
                tgt_eos=tgt_eos,
                subword_option=subword_option)

            if infer_summary is not None:  # Attention models
                image_file = output_infer_summary_prefix + str(decode_id) + ".png"
                utils.print_out("  save attention image to %s*" % image_file)
                image_summ = tf.Summary()
                image_summ.ParseFromString(infer_summary)
                # Binary mode: encoded_image_string is raw PNG bytes.
                with tf.gfile.GFile(image_file, mode="wb") as img_f:
                    img_f.write(image_summ.value[0].image.encoded_image_string)

            trans_f.write("%s\n" % translation)
            utils.print_out(translation + b"\n")
    utils.print_time("  done", start_time)
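
The `codecs.getwriter("utf-8")` wrapper above adapts a binary file handle so that `.write()` accepts `str` and encodes it to UTF-8 bytes before they reach the underlying stream. A standalone sketch with the standard library, using plain `open` in place of `tf.gfile.GFile` (the path is an arbitrary example):

import codecs

with codecs.getwriter("utf-8")(open("/tmp/trans.txt", "wb")) as trans_f:
    trans_f.write("")         # ensure the file exists even if nothing is decoded
    trans_f.write("héllo\n")  # str in, UTF-8 bytes on disk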
Example 4
def _sample_decode(model, global_step, sess, hparams, iterator, src_data,
                   tgt_data, iterator_src_placeholder,
                   iterator_batch_size_placeholder, summary_writer):
    """Pick a sentence and decode."""
    decode_id = random.randint(0, len(src_data) - 1)
    utils.print_out("  # %d" % decode_id)

    iterator_feed_dict = {
        iterator_src_placeholder: [src_data[decode_id]],
        iterator_batch_size_placeholder: 1,
    }
    sess.run(iterator.initializer, feed_dict=iterator_feed_dict)

    nmt_outputs, attention_summary = model.decode(sess)

    if hparams.beam_width > 0:
        # get the top translation.
        nmt_outputs = nmt_outputs[0]

    translation = nmt_utils.get_translation(
        nmt_outputs,
        sent_id=0,
        tgt_eos=hparams.eos,
        subword_option=hparams.subword_option)
    utils.print_out("    src: %s" % src_data[decode_id])
    utils.print_out("    ref: %s" % tgt_data[decode_id])
    utils.print_out(b"    nmt: " + translation)

    # Summary
    if attention_summary is not None:
        summary_writer.add_summary(attention_summary, global_step)
Example 5
    def _assertBeamSearchOutputs(self, m, sess, assert_top_k_sentence, name):
        nmt_outputs, _ = m.decode(sess)

        for i in range(assert_top_k_sentence):
            output_words = nmt_outputs[i]
            for j in range(output_words.shape[0]):
                sentence = nmt_utils.get_translation(output_words,
                                                     j,
                                                     tgt_eos=EOS,
                                                     subword_option='')
                sentence_key = ('%s: batch %d of beam %d' % (name, j, i))
                self.actual_beam_sentences[sentence_key] = sentence
                expected_sentence = self.expected_beam_sentences[sentence_key]
                self.assertEqual(expected_sentence, sentence)
Example 6
    async def talk(self, message):
        logging.info('MSG: ' + message.content + ' in ' + message.channel.name)

        luck = random.random()
        if luck > 0:  # almost always true: random() lies in [0, 1)
            logging.info('triggered')
            start_time = time.time()
            try:
                await self.bot.send_typing(message.channel)
            except Exception:
                pass  # the typing indicator is cosmetic; ignore failures

            self.session.run(self.infer_model.iterator.initializer,
                             feed_dict={
                                 self.infer_model.src_placeholder:
                                 [message.content],
                                 self.infer_model.batch_size_placeholder:
                                 self.hparams.infer_batch_size
                             })
            num_translations_per_input = max(
                min(self.hparams.num_translations_per_input,
                    self.hparams.beam_width), 1)

            nmt_outputs, _ = self.loaded_infer_model.decode(self.session)
            if self.hparams.beam_width == 0:
                nmt_outputs = np.expand_dims(nmt_outputs, 0)

            batch_size = nmt_outputs.shape[1]

            for sent_id in range(batch_size):
                for beam_id in range(num_translations_per_input):
                    response = nmt_utils.get_translation(
                        nmt_outputs[beam_id],
                        sent_id,
                        tgt_eos=self.hparams.eos,
                        subword_option=self.hparams.subword_option)
                    end_time = time.time()
                    logging.info('Time taken for response:' +
                                 str(end_time - start_time))

                    #clean_msg = self.ping_replace.sub('', response)
                    clean_msg = (str(response, 'utf-8')
                                 .replace('<unk>', '')
                                 .replace('\n', '')
                                 .strip())
                    await self.bot.send_message(message.channel, clean_msg)
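
The `clean_msg` chain above is the entire post-processing step: decode the bytes, drop `<unk>` tokens and newlines, and trim whitespace. The same logic as a standalone helper (the sample input is made up for illustration):

def clean_response(response: bytes) -> str:
    # Mirrors the clean_msg chain: decode, strip <unk> and newlines, trim.
    return (str(response, 'utf-8')
            .replace('<unk>', '')
            .replace('\n', '')
            .strip())

assert clean_response(b"hello <unk> world\n") == "hello  world"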
Example 7
# Imports as laid out in the TensorFlow NMT repo (package layout assumed):
import tensorflow as tf
from nmt import attention_model, model_helper, nmt
from nmt.utils import misc_utils, nmt_utils


def generate_reply(input_text, flags):
    # Format data (tokenize_text is an external helper defined elsewhere).
    tokenized_text = tokenize_text(input_text)
    infer_data = [tokenized_text]
    # Load hparams.
    jobid = flags.jobid
    model_dir = flags.out_dir  # out_dir is the standard NMT model-directory flag
    default_hparams = nmt.create_hparams(flags)
    hparams = nmt.create_or_load_hparams(model_dir,
                                         default_hparams,
                                         flags.hparams_path,
                                         save_hparams=(jobid == 0))
    # Load checkpoint
    ckpt = tf.train.latest_checkpoint(model_dir)
    # Inference
    model_creator = attention_model.AttentionModel
    # Create model
    scope = None
    infer_model = model_helper.create_infer_model(model_creator, hparams,
                                                  scope)

    with tf.Session(graph=infer_model.graph,
                    config=misc_utils.get_config_proto()) as sess:
        model = model_helper.load_model(infer_model.model, ckpt, sess, "infer")
        sess.run(infer_model.iterator.initializer,
                 feed_dict={
                     infer_model.src_placeholder: infer_data,
                     infer_model.batch_size_placeholder:
                     hparams.infer_batch_size
                 })
        # Decode
        nmt_outputs, infer_summary = model.decode(sess)
        # get text translation(reply as a chatbot)
        assert nmt_outputs.shape[0] == 1
        translation = nmt_utils.get_translation(
            nmt_outputs,
            sent_id=0,
            tgt_eos=hparams.eos,
            subword_option=hparams.subword_option)

    return translation.decode("utf-8")
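
A hypothetical driver for `generate_reply`, assuming the standard NMT argument parser (`nmt.add_arguments` registers `--out_dir`, `--hparams_path`, `--jobid`, and the other flags this function reads; the paths are placeholders):

import argparse

parser = argparse.ArgumentParser()
nmt.add_arguments(parser)
flags, _ = parser.parse_known_args(["--out_dir=/tmp/chatbot_model"])
print(generate_reply("hello there", flags))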