コード例 #1
0
ファイル: translate.py プロジェクト: RChandrasekar/tensorflow
def decode():
    """Interactively translate English sentences from stdin into French.

    Restores the trained model and both vocabularies, then repeatedly prompts
    with "> ", reads one line from standard input, greedily decodes it and
    prints the French tokens. The loop ends when stdin reaches EOF.
    """
    import logging  # local: only needed for the over-long-input warning

    with tf.Session() as sess:
        # Create model and load parameters.
        model = create_model(sess, True)
        model.batch_size = 1  # We decode one sentence at a time.

        # Load vocabularies.
        en_vocab_path = os.path.join(FLAGS.data_dir, "vocab%d.en" % FLAGS.en_vocab_size)
        fr_vocab_path = os.path.join(FLAGS.data_dir, "vocab%d.fr" % FLAGS.fr_vocab_size)
        en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
        _, rev_fr_vocab = data_utils.initialize_vocabulary(fr_vocab_path)

        # Decode from standard input.
        sys.stdout.write("> ")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
        while sentence:
            # Get token-ids for the input sentence.
            token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), en_vocab)
            # Pick the smallest bucket whose encoder size exceeds the input
            # length. A sentence too long for every bucket falls back to the
            # largest one instead of raising ValueError from min([]) and
            # killing the whole interactive session.
            fitting = [b for b in xrange(len(_buckets)) if _buckets[b][0] > len(token_ids)]
            if fitting:
                bucket_id = min(fitting)
            else:
                # NOTE(review): assumes model.get_batch tolerates inputs longer
                # than the bucket (truncating them) — confirm against the model.
                bucket_id = len(_buckets) - 1
                logging.warning("Sentence longer than the largest bucket: %s", sentence)
            # Get a 1-element batch to feed the sentence to the model.
            encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(token_ids, [])]}, bucket_id)
            # Get output logits for the sentence.
            _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
            # This is a greedy decoder - outputs are just argmaxes of output_logits.
            outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
            # If there is an EOS symbol in outputs, cut them at that point.
            if data_utils.EOS_ID in outputs:
                outputs = outputs[: outputs.index(data_utils.EOS_ID)]
            # Print out French sentence corresponding to outputs.
            print(" ".join([tf.compat.as_str(rev_fr_vocab[output]) for output in outputs]))
            print("> ", end="")
            sys.stdout.flush()
            sentence = sys.stdin.readline()
コード例 #2
0
    def decode(self, sentence):
        """Greedily translate one English sentence and return the French text.

        Args:
          sentence: the input sentence; passed through tf.compat.as_bytes and
            tokenized with self.en_vocab.

        Returns:
          The output tokens looked up in self.fr_vocab (indexed by output id),
          joined by single spaces and cut at the first EOS symbol.
        """
        import logging  # local: only needed for the over-long-input warning

        # Get token-ids for the input sentence.
        token_ids = data_utils.sentence_to_token_ids(
            tf.compat.as_bytes(sentence), self.en_vocab)
        # Pick the smallest bucket whose encoder size exceeds the input length.
        # A sentence too long for every bucket falls back to the largest one
        # instead of raising ValueError from min() over an empty list.
        fitting = [
            b for b in xrange(len(_buckets)) if _buckets[b][0] > len(token_ids)
        ]
        if fitting:
            bucket_id = min(fitting)
        else:
            # NOTE(review): assumes self.model.get_batch tolerates inputs longer
            # than the bucket (truncating them) — confirm against the model.
            bucket_id = len(_buckets) - 1
            logging.warning("Sentence longer than the largest bucket: %s",
                            sentence)
        # Get a 1-element batch to feed the sentence to the model.
        encoder_inputs, decoder_inputs, target_weights = self.model.get_batch(
            {bucket_id: [(token_ids, [])]}, bucket_id)
        # Get output logits for the sentence.
        _, _, output_logits = self.model.step(self.sess, encoder_inputs,
                                              decoder_inputs, target_weights,
                                              bucket_id, True)
        # This is a greedy decoder - outputs are just argmaxes of output_logits.
        outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
        # If there is an EOS symbol in outputs, cut them at that point.
        if data_utils.EOS_ID in outputs:
            outputs = outputs[:outputs.index(data_utils.EOS_ID)]

        # Build the French sentence corresponding to outputs.
        translated_str = " ".join(
            [tf.compat.as_str(self.fr_vocab[output]) for output in outputs])

        return translated_str
コード例 #3
0
    def get_response(self, sentence):
        """Generate a response to *sentence* by temperature sampling.

        Each output position is drawn at random from
        softmax(logits / FLAGS.temperature) rather than taken as the argmax.
        Lower temperatures make the output closer to greedy decoding. A
        non-positive temperature is treated as plain greedy (argmax) decoding.

        Args:
          sentence: the input sentence; passed through tf.compat.as_bytes and
            tokenized with self.source_vocab.

        Returns:
          The sampled tokens looked up in self.rev_target_vocab, joined by
          single spaces and cut at the first EOS symbol.
        """
        # Get token-ids for the input sentence.
        token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), self.source_vocab)
        # Which bucket does it belong to? Fall back to the largest bucket, with
        # a warning, when the sentence fits none of them.
        bucket_id = len(_buckets) - 1
        for i, bucket in enumerate(_buckets):
            if bucket[0] >= len(token_ids):
                bucket_id = i
                break
        else:
            logging.warning("Sentence truncated: %s", sentence)

        # Get a 1-element batch to feed the sentence to the model.
        encoder_inputs, decoder_inputs, target_weights = self.model.get_batch(
            {bucket_id: [(token_ids, [])]}, bucket_id)
        # Get output logits for the sentence.
        _, _, output_logits = self.model.step(self.sess, encoder_inputs, decoder_inputs,
                                              target_weights, bucket_id, True)

        temperature = FLAGS.temperature
        outputs = []
        for logit in output_logits:
            # Row 0 holds the scores for the single sentence fed above; using
            # float64 keeps the normalised probabilities summing to 1 within
            # np.random.choice's tolerance.
            row = np.asarray(logit[0], dtype=np.float64)
            if temperature <= 0:
                # Dividing by a zero temperature would yield NaN probabilities
                # (and a negative one inverts the distribution); use argmax.
                outputs.append(int(np.argmax(row)))
                continue
            scaled = row / temperature
            # Shift by the max before exponentiating so exp() cannot overflow.
            weights = np.exp(scaled - scaled.max())
            probs = weights / weights.sum()
            # Sample over the real logit width so it cannot disagree with the
            # configured vocabulary size.
            outputs.append(int(np.random.choice(len(probs), p=probs)))

        # If there is an EOS symbol in outputs, cut them at that point.
        if data_utils.EOS_ID in outputs:
            outputs = outputs[:outputs.index(data_utils.EOS_ID)]
        # Build the sentence corresponding to outputs.
        return " ".join([tf.compat.as_str(self.rev_target_vocab[output]) for output in outputs])
コード例 #4
0
ファイル: pos.py プロジェクト: ruhulsbu/tensorflow
def decode():
    """Tag every sentence of <FLAGS.data_dir>/test_pos.txt and print the tags.

    Restores the trained model and the word/tag vocabularies, then for each
    non-blank line of the test file prints the sentence, its whitespace
    tokens, the greedily decoded output ids and the corresponding tags.
    Lines that tokenize to nothing, or that fit no bucket, are skipped.
    """
    print("Decoding")
    with tf.Session() as sess:
        # Create model and load parameters.
        model = create_model(sess, True)
        model.batch_size = 1  # We decode one sentence at a time.

        # Load vocabularies.
        en_vocab_path = os.path.join(FLAGS.data_dir, "vocab%d.txt" % FLAGS.en_vocab_size)
        fr_vocab_path = os.path.join(FLAGS.data_dir, "vocab%d.tags" % FLAGS.fr_vocab_size)
        en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
        _, rev_fr_vocab = data_utils.initialize_vocabulary(fr_vocab_path)

        test_file_path = os.path.join(FLAGS.data_dir, "test_pos.txt")

        sys.stdout.write("> ")
        sys.stdout.flush()
        print("Reading Test File from: " + test_file_path)

        # `with` guarantees the test file is closed, even if decoding raises.
        with open(test_file_path, "r") as read_test_file:
            for sentence in read_test_file:
                # File lines keep their trailing newline, so a len() == 0 test
                # never fires; check the stripped text instead.
                if not sentence.strip():
                    continue
                print("\nSentence = " + sentence)
                tokenized_list = sentence.strip().split()
                print(tokenized_list)
                print("Length of Tokenized Words: " + str(len(tokenized_list)))
                print("Tokenized with Basic Tokenizer")
                # Get token-ids for the input sentence.
                token_ids = data_utils.sentence_to_token_ids(sentence, en_vocab)
                if len(token_ids) == 0:
                    continue
                # Smallest bucket whose encoder size exceeds the input length.
                bucket_array = [b for b in xrange(len(_buckets)) if _buckets[b][0] > len(token_ids)]
                if not bucket_array:
                    # Report the skip instead of dropping the sentence silently.
                    print("Skipping sentence: longer than the largest bucket")
                    continue
                bucket_id = min(bucket_array)
                # Get a 1-element batch to feed the sentence to the model.
                encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(token_ids, [])]}, bucket_id)
                # Get output logits for the sentence.
                _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
                # This is a greedy decoder - outputs are just argmaxes of output_logits.
                outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
                # If there is an EOS symbol in outputs, cut them at that point.
                if data_utils.EOS_ID in outputs:
                    outputs = outputs[: outputs.index(data_utils.EOS_ID)]
                print("Final Output: ")
                print("______________")
                print(outputs)
                # Print out the tag sequence corresponding to outputs.
                print(" ".join([rev_fr_vocab[output] for output in outputs]))
                print("Total Length of Tags: " + str(len(outputs)))

                print("\n> ", end="\n")
コード例 #5
0
ファイル: pos.py プロジェクト: ruhulsbu/tensorflow
def evaluate_sentence(model, sess):
    """Tag a fixed sample sentence with *model* and print the result.

    Intended as a quick sanity check during training. model.batch_size is set
    to 1 for the call and restored to the caller's value afterwards, even if
    decoding raises.

    Args:
      model: the tagging model; must provide get_batch and step.
      sess: the session passed to model.step.
    """
    import logging  # local: only needed for the over-long-input warning

    saved_batch_size = model.batch_size
    model.batch_size = 1  # We decode one sentence at a time.
    try:
        # Load vocabularies.
        en_vocab_path = os.path.join(FLAGS.data_dir, "vocab%d.txt" % FLAGS.en_vocab_size)
        fr_vocab_path = os.path.join(FLAGS.data_dir, "vocab%d.tags" % FLAGS.fr_vocab_size)
        en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
        _, rev_fr_vocab = data_utils.initialize_vocabulary(fr_vocab_path)

        sentence = "He reckons the current account deficit will narrow to only # 1.8 billion in September ."
        print(sentence)
        # Get token-ids for the input sentence.
        token_ids = data_utils.sentence_to_token_ids(sentence, en_vocab)
        # Smallest bucket whose encoder size exceeds the input length; fall
        # back to the largest bucket rather than raising from min([]).
        fitting = [b for b in xrange(len(_buckets)) if _buckets[b][0] > len(token_ids)]
        if fitting:
            bucket_id = min(fitting)
        else:
            # NOTE(review): assumes model.get_batch tolerates inputs longer
            # than the bucket (truncating them) — confirm against the model.
            bucket_id = len(_buckets) - 1
            logging.warning("Sentence longer than the largest bucket: %s", sentence)
        # Get a 1-element batch to feed the sentence to the model.
        encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(token_ids, [])]}, bucket_id)
        # Get output logits for the sentence.
        _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
        # This is a greedy decoder - outputs are just argmaxes of output_logits.
        outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
        # If there is an EOS symbol in outputs, cut them at that point.
        if data_utils.EOS_ID in outputs:
            outputs = outputs[: outputs.index(data_utils.EOS_ID)]
        # Print unconditionally: these prints were previously nested under the
        # EOS check, so nothing was shown when the model emitted no EOS.
        print(outputs)
        print(" ".join([rev_fr_vocab[output] for output in outputs]))
    finally:
        model.batch_size = saved_batch_size
コード例 #6
0
ファイル: translate.py プロジェクト: ivan2110/Translator2016
def test():
  """Score the translation model on a parallel test corpus with BLEU.

  Each line of the source-language sentence file is translated greedily. The
  translation is compared with the same-numbered line of the
  destination-language file using nltk's sentence_bleu (uniform 1- to 4-gram
  weights). A running average is printed every 10 scored sentences, and the
  final average at the end. Sentences that fit no bucket, or for which
  sentence_bleu raises ZeroDivisionError, are skipped and not counted.
  """
  nltk.download('punkt')
  with tf.Session() as sess:
    model = create_model(sess, True)
    model.batch_size = 1  # We decode one sentence at a time.

    # Load vocabularies.
    src_lang_vocab_path = PATH_TO_DATA_FILES + FLAGS.src_lang + "_mapping%d.txt" % FLAGS.src_lang_vocab_size
    dst_lang_vocab_path = PATH_TO_DATA_FILES + FLAGS.dst_lang + "_mapping%d.txt" % FLAGS.dst_lang_vocab_size
    src_lang_vocab, _ = data_utils.initialize_vocabulary(src_lang_vocab_path)
    _, rev_dst_lang_vocab = data_utils.initialize_vocabulary(dst_lang_vocab_path)

    weights = [0.25, 0.25, 0.25, 0.25]

    total_bleu_value = 0.0
    computing_bleu_iterations = 0

    # `with` guarantees both corpus files are closed when scoring finishes.
    with open(generate_src_lang_sentences_file_name(FLAGS.src_lang)) as first_lang_file:
      with open(generate_src_lang_sentences_file_name(FLAGS.dst_lang)) as second_lang_file:
        for first_lang_raw in first_lang_file:
          # Read the gold line in lockstep so line N of both files stays paired.
          second_lang_gold_raw = second_lang_file.readline()
          # Get token-ids for the input sentence.
          token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(first_lang_raw), src_lang_vocab)
          # Which bucket does it belong to? Skip sentences that fit none.
          try:
            bucket_id = min([b for b in xrange(len(_buckets))
                             if _buckets[b][0] > len(token_ids)])
          except ValueError:
            continue
          # Get a 1-element batch to feed the sentence to the model.
          encoder_inputs, decoder_inputs, target_weights = model.get_batch(
              {bucket_id: [(token_ids, [])]}, bucket_id)
          # Get output logits for the sentence.
          _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
          # This is a greedy decoder - outputs are just argmaxes of output_logits.
          outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
          # If there is an EOS symbol in outputs, cut them at that point.
          if data_utils.EOS_ID in outputs:
            outputs = outputs[:outputs.index(data_utils.EOS_ID)]
          # Build the sentence corresponding to outputs.
          model_tran_res = " ".join([tf.compat.as_str(rev_dst_lang_vocab[output]) for output in outputs])
          second_lang_gold_tokens = word_tokenize(second_lang_gold_raw)
          model_tran_res_tokens = word_tokenize(model_tran_res)
          try:
            # sentence_bleu(references, hypothesis, weights): the gold
            # translation is the reference and the model output is the
            # hypothesis. They were previously passed the other way round.
            current_bleu_value = sentence_bleu([second_lang_gold_tokens], model_tran_res_tokens, weights)
          except ZeroDivisionError:
            continue
          total_bleu_value += current_bleu_value
          computing_bleu_iterations += 1
          # Report only right after a new score is added. The check used to
          # run on skipped sentences too, dividing by zero while the count
          # was 0 and repeating the same line.
          if computing_bleu_iterations % 10 == 0:
            print("BLEU value after %d iterations: %.2f"
                  % (computing_bleu_iterations, total_bleu_value / computing_bleu_iterations))

    if computing_bleu_iterations == 0:
      # Nothing was scored: avoid dividing by zero.
      print("Final BLEU value after 0 iterations: undefined (no sentence could be scored)")
      return
    final_bleu_value = total_bleu_value / computing_bleu_iterations
    print("Final BLEU value after %d iterations: %.2f" % (computing_bleu_iterations, final_bleu_value))
    return
コード例 #7
0
ファイル: translate.py プロジェクト: gvteja/ICTC
def decode():
    """Interactively translate English sentences from stdin into French.

    Restores the trained model and both vocabularies, then repeatedly prompts
    with "> ", reads a line from standard input, greedily decodes it and
    prints the French tokens, until stdin reaches EOF.

    Input too long for every bucket is reported and skipped. Other errors
    raised while decoding one line are printed and the loop continues.
    KeyboardInterrupt and SystemExit still stop the program.
    """
    with tf.Session() as sess:
        # Create model and load parameters.
        model = create_model(sess, True)
        model.batch_size = 1  # We decode one sentence at a time.

        # Load vocabularies.
        en_vocab_path = os.path.join(FLAGS.data_dir,
                                     "vocab%d.en" % FLAGS.en_vocab_size)
        fr_vocab_path = os.path.join(FLAGS.data_dir,
                                     "vocab%d.fr" % FLAGS.fr_vocab_size)
        en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
        _, rev_fr_vocab = data_utils.initialize_vocabulary(fr_vocab_path)

        # Decode from standard input.
        sys.stdout.write("> ")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
        while sentence:
            try:
                # Get token-ids for the input sentence.
                token_ids = data_utils.sentence_to_token_ids(
                    tf.compat.as_bytes(sentence), en_vocab)
                # Which bucket does it belong to? Check explicitly instead of
                # relying on min([]) raising inside a catch-all handler.
                fitting = [
                    b for b in xrange(len(_buckets))
                    if _buckets[b][0] > len(token_ids)
                ]
                if not fitting:
                    print("Exception: input too long")
                else:
                    bucket_id = min(fitting)
                    # Get a 1-element batch to feed the sentence to the model.
                    encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                        {bucket_id: [(token_ids, [])]}, bucket_id)
                    # Get output logits for the sentence.
                    _, _, output_logits = model.step(sess, encoder_inputs,
                                                     decoder_inputs,
                                                     target_weights, bucket_id,
                                                     True)
                    # This is a greedy decoder - outputs are just argmaxes of output_logits.
                    outputs = [
                        int(np.argmax(logit, axis=1)) for logit in output_logits
                    ]
                    # If there is an EOS symbol in outputs, cut them at that point.
                    if data_utils.EOS_ID in outputs:
                        outputs = outputs[:outputs.index(data_utils.EOS_ID)]

                    # Print out French sentence corresponding to outputs.
                    print(" ".join([
                        tf.compat.as_str(rev_fr_vocab[output])
                        for output in outputs
                    ]))
            except Exception as exc:
                # Best-effort: keep the session alive, but report the real
                # error instead of always claiming the input was too long.
                # (A bare except also swallowed KeyboardInterrupt.)
                print("Exception: %s" % exc)
            # Prompt again outside any `finally`, so an interrupt propagates
            # immediately instead of blocking on another readline().
            print("> ", end="")
            sys.stdout.flush()
            sentence = sys.stdin.readline()
コード例 #8
0
def decode(data_dir, train_dir):
    """Chat interactively: answer each line read from stdin with the model.

    Args:
      data_dir: directory holding the pickled vocabularies "w2idx_q"
        (question side) and "w2idx_a" (answer side).
      train_dir: checkpoint directory passed to create_model.

    Prompts with "> ", greedily decodes each input line and prints the answer
    tokens, until stdin reaches EOF.
    """
    import logging  # local: only needed for the over-long-input warning

    with tf.Session() as sess:
        # Create model and load parameters.
        model = create_model(sess, train_dir, True)
        model.batch_size = 1  # We decode one sentence at a time.

        # Load vocabularies. `with` closes the files; the bare open() calls
        # leaked them.
        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file; data_dir must only ever point at trusted data.
        with open(os.path.join(data_dir, "w2idx_q"), "rb") as vocab_file:
            que_vocab = pickle.load(vocab_file)
        with open(os.path.join(data_dir, "w2idx_a"), "rb") as vocab_file:
            ans_vocab = pickle.load(vocab_file)
        # Hard-coded id overrides for a few tokens. NOTE(review): presumably
        # these align the pickled ids with the ids the model was trained on —
        # confirm before changing them.
        ans_vocab["_go_"] = 1
        ans_vocab["_eos_"] = 2
        que_vocab["."] = 6002
        ans_vocab["."] = 6002
        que_vocab["the"] = 6003
        ans_vocab["the"] = 6003
        # Reverse map id -> word, used to turn output ids back into text.
        rev_ans_vocab = {v: k for (k, v) in ans_vocab.items()}

        # Decode from standard input.
        sys.stdout.write("> ")
        sys.stdout.flush()
        sentence = sys.stdin.readline()

        while sentence:
            # Get token-ids for the input sentence.
            token_ids = data_utils.sentence_to_token_ids(
                tf.compat.as_bytes(sentence), que_vocab)
            # Pick the smallest bucket whose encoder size exceeds the input
            # length; fall back to the largest one instead of raising
            # ValueError from min([]) and ending the session.
            fitting = [
                b for b in xrange(len(_buckets))
                if _buckets[b][0] > len(token_ids)
            ]
            if fitting:
                bucket_id = min(fitting)
            else:
                # NOTE(review): assumes model.get_batch tolerates inputs longer
                # than the bucket (truncating them) — confirm against the model.
                bucket_id = len(_buckets) - 1
                logging.warning("Sentence longer than the largest bucket: %s",
                                sentence)
            # Get a 1-element batch to feed the sentence to the model.
            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                {bucket_id: [(token_ids, [])]}, bucket_id)
            # Get output logits for the sentence.
            _, _, output_logits = model.step(sess, encoder_inputs,
                                             decoder_inputs, target_weights,
                                             bucket_id, True)
            # This is a greedy decoder - outputs are just argmaxes of output_logits.
            outputs = [
                int(np.argmax(logit, axis=1)) for logit in output_logits
            ]
            # If there is an EOS symbol in outputs, cut them at that point.
            if data_utils.EOS_ID in outputs:
                outputs = outputs[:outputs.index(data_utils.EOS_ID)]
            # Print out the answer sentence corresponding to outputs.
            print(" ".join([
                tf.compat.as_str(rev_ans_vocab[output]) for output in outputs
            ]))
            print("> ", end="")
            sys.stdout.flush()
            sentence = sys.stdin.readline()
コード例 #9
0
def sentiment_sentence(sess, gen_model, sentence, is_beam_search=False):
    """Greedily decode *sentence* with *gen_model* and return the output text.

    Args:
      sess: session passed to gen_model.step.
      gen_model: the generator model. Its batch_size is set to 1 for this call
        and restored to its previous value afterwards, even if decoding raises.
      sentence: input text; one trailing newline sequence is stripped.
      is_beam_search: currently unused (beam search is not implemented).

    Returns:
      The output tokens joined by spaces, cut at the first EOS_ID. Ids outside
      rev_target_vocab are dropped, and the space before ? , . ! " (when
      followed by whitespace or end of string) is removed.
    """
    sentence = sentence.rstrip('\n')
    source_vocab_path, target_vocab_path = data_utils.get_source_target_vocab_path(FLAGS.dict_dir,
                                                                                   FLAGS.source_vocab_size,
                                                                                   FLAGS.target_vocab_size)
    # Load vocabularies.
    source_vocab, _ = initialize_vocabulary(source_vocab_path)
    _, rev_target_vocab = initialize_vocabulary(target_vocab_path)

    # Remember the caller's batch size. It used to be reset to
    # FLAGS.batch_size, which is wrong when the model was built with a
    # different size, and was skipped entirely when decoding raised.
    original_batch_size = gen_model.batch_size
    gen_model.batch_size = 1  # We decode one sentence at a time.
    try:
        # Get token-ids for the input sentence.
        token_ids = sentence_to_token_ids(sentence, source_vocab, tokenizer=basic_tokenizer)
        # Smallest bucket whose encoder size exceeds the input length, else
        # the largest bucket.
        seq = [b for b in range(len(_buckets)) if _buckets[b][0] > len(token_ids)]
        bucket_id = min(seq) if seq else len(_buckets) - 1

        # Get a 1-element batch to feed the sentence to the model.
        encoder_inputs, decoder_inputs, target_weights = gen_model.get_batch(
            {bucket_id: [(token_ids, [])]}, bucket_id)
        # Get output logits for the sentence.
        _, _, output_logits = gen_model.step(sess, encoder_inputs, decoder_inputs,
                                             target_weights, bucket_id, True)

        # TODO implement beam search
        # This is a greedy decoder - outputs are just argmaxes of output_logits.
        outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
        print(outputs)

        # If there is an EOS symbol in outputs, cut them at that point.
        if EOS_ID in outputs:
            outputs = outputs[:outputs.index(EOS_ID)]
    finally:
        gen_model.batch_size = original_batch_size

    vocab_size = len(rev_target_vocab)
    output_string = " ".join([rev_target_vocab[output] for output in outputs if output < vocab_size])
    # Remove the space before punctuation.
    output_string = re.sub(r'\s([?,.!"](?:\s|$))', r'\1', output_string)
    return output_string
コード例 #10
0
ファイル: translate.py プロジェクト: stprior/tensorflow
def decode():
  """Translate English lines from stdin into French and summarise hidden states.

  Restores the trained model, then for every line read from standard input
  prints the greedy French decoding. It then prints one line that joins
  summarise_state() of each hidden state returned by model.step_with_states.
  A "> " prompt precedes every read; the loop ends at EOF.
  """
  with tf.Session() as sess:
    # Restore trained parameters; decoding feeds one sentence per batch.
    model = create_model(sess, True)
    model.batch_size = 1

    # Source vocabulary for tokenizing; target reverse vocabulary, indexed by
    # output id, for rendering.
    vocab_dir = FLAGS.data_dir
    en_vocab, _ = data_utils.initialize_vocabulary(
        os.path.join(vocab_dir, "vocab%d.en" % FLAGS.en_vocab_size))
    _, rev_fr_vocab = data_utils.initialize_vocabulary(
        os.path.join(vocab_dir, "vocab%d.fr" % FLAGS.fr_vocab_size))

    sys.stdout.write("> ")
    sys.stdout.flush()
    # readline() returns "" only at EOF, which ends the iteration.
    for sentence in iter(sys.stdin.readline, ""):
      token_ids = data_utils.sentence_to_token_ids(
          tf.compat.as_bytes(sentence), en_vocab)

      # First bucket whose encoder side can hold every token; otherwise the
      # largest bucket, logging that the sentence was truncated.
      n_tokens = len(token_ids)
      bucket_id = next(
          (idx for idx, bucket in enumerate(_buckets) if bucket[0] >= n_tokens),
          None)
      if bucket_id is None:
        bucket_id = len(_buckets) - 1
        logging.warning("Sentence truncated: %s", sentence)

      batch = model.get_batch({bucket_id: [(token_ids, [])]}, bucket_id)
      encoder_inputs, decoder_inputs, target_weights = batch
      _, _, output_logits, hidden_states = model.step_with_states(
          sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)

      # Greedy decoding: argmax token at each position, cut at the first EOS.
      outputs = [int(np.argmax(step_logits, axis=1)) for step_logits in output_logits]
      if data_utils.EOS_ID in outputs:
        del outputs[outputs.index(data_utils.EOS_ID):]

      print(" ".join(tf.compat.as_str(rev_fr_vocab[idx]) for idx in outputs))
      print(" ".join(summarise_state(state) for state in hidden_states))
      print("> ", end="")
      sys.stdout.flush()
コード例 #11
0
        def single_sentence_decoding(sentence):
            """Run the model on one sentence and return decode_once's result.

            Uses en_vocab, model, sess and rev_fr_vocab from the enclosing
            scope. A sentence too long for every bucket falls back to the
            largest bucket (with a warning) instead of raising ValueError
            from min() over an empty list.
            """
            import logging  # local: only needed for the over-long-input warning

            # Get token-ids for the input sentence.
            token_ids = data_utils.sentence_to_token_ids(
                tf.compat.as_bytes(sentence), en_vocab)
            # Smallest bucket whose encoder size exceeds the input length.
            fitting = [
                b for b in xrange(len(_buckets))
                if _buckets[b][0] > len(token_ids)
            ]
            if fitting:
                bucket_id = min(fitting)
            else:
                # NOTE(review): assumes model.get_batch tolerates inputs longer
                # than the bucket (truncating them) — confirm against the model.
                bucket_id = len(_buckets) - 1
                logging.warning("Sentence longer than the largest bucket: %s",
                                sentence)
            # Get a 1-element batch to feed the sentence to the model.
            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                {bucket_id: [(token_ids, [])]}, bucket_id)
            # Get output logits for the sentence.
            _, _, output_logits = model.step(sess, encoder_inputs,
                                             decoder_inputs, target_weights,
                                             bucket_id, True)

            return decode_once(output_logits, rev_fr_vocab)
コード例 #12
0
def seq2seq(echo):
    """Answer *echo* with the chat model and wrap the reply in a statement.

    Returns statement(render_template('notyet')) while the module-level sess
    or model is not available yet. Otherwise the greedy decoding of *echo*
    (cut at EOS) is printed and returned as a statement. Input too long for
    every bucket falls back to the largest bucket instead of raising
    ValueError from min([]) inside the request handler.
    """
    import logging  # local: only needed for the over-long-input warning

    print(echo)
    sentence = echo
    if not sess or not model:
        return statement(render_template('notyet'))
    # Get token-ids for the input sentence.
    token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), que_vocab)
    # Smallest bucket whose encoder size exceeds the input length.
    fitting = [b for b in xrange(len(_buckets)) if _buckets[b][0] > len(token_ids)]
    if fitting:
        bucket_id = min(fitting)
    else:
        # NOTE(review): assumes model.get_batch tolerates inputs longer than
        # the bucket (truncating them) — confirm against the model.
        bucket_id = len(_buckets) - 1
        logging.warning("Sentence longer than the largest bucket: %s", sentence)
    # Get a 1-element batch to feed the sentence to the model.
    encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(token_ids, [])]}, bucket_id)
    # Get output logits for the sentence.
    _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
    # This is a greedy decoder - outputs are just argmaxes of output_logits.
    outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
    # If there is an EOS symbol in outputs, cut them at that point.
    if data_utils.EOS_ID in outputs:
        outputs = outputs[:outputs.index(data_utils.EOS_ID)]
    response = " ".join([tf.compat.as_str(rev_ans_vocab[output]) for output in outputs])
    print(response)
    return statement(response)
コード例 #13
0
ファイル: server.py プロジェクト: WoodyZantzinger/neural_chat
def decode(sentence):
    """Greedily decode *sentence* with the module-level model and return text.

    Uses the module-level en_vocab, model, sess and rev_fr_vocab. Output is
    cut at the first EOS symbol. Input too long for every bucket falls back
    to the largest bucket (with a warning) instead of raising ValueError from
    min([]), which would otherwise fail the server request.
    """
    import logging  # local: only needed for the over-long-input warning

    # Get token-ids for the input sentence.
    token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), en_vocab)

    # Smallest bucket whose encoder size exceeds the input length.
    fitting = [b for b in xrange(len(_buckets))
               if _buckets[b][0] > len(token_ids)]
    if fitting:
        bucket_id = min(fitting)
    else:
        # NOTE(review): assumes model.get_batch tolerates inputs longer than
        # the bucket (truncating them) — confirm against the model.
        bucket_id = len(_buckets) - 1
        logging.warning("Sentence longer than the largest bucket: %s", sentence)

    # Get a 1-element batch to feed the sentence to the model.
    encoder_inputs, decoder_inputs, target_weights = model.get_batch(
        {bucket_id: [(token_ids, [])]}, bucket_id)

    # Get output logits for the sentence.
    _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                     target_weights, bucket_id, True)

    # This is a greedy decoder - outputs are just argmaxes of output_logits.
    outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
    # If there is an EOS symbol in outputs, cut them at that point.
    if data_utils.EOS_ID in outputs:
        outputs = outputs[:outputs.index(data_utils.EOS_ID)]
    # Return sentence corresponding to outputs.
    return " ".join([tf.compat.as_str(rev_fr_vocab[output]) for output in outputs])
コード例 #14
0
def translate_add(self, sentence):
    """Answer a question with the module-level model and session.

    Both vocabularies are reloaded from FLAGS.data_dir on every call. The
    question is decoded greedily (cut at EOS), and the answer is printed and
    returned. Input too long for every bucket falls back to the largest
    bucket (with a warning) instead of raising ValueError from min([]).

    Args:
      self: not used by the body; kept for interface compatibility.
      sentence: the question text; passed through tf.compat.as_bytes.

    Returns:
      The answer tokens joined by single spaces.
    """
    import logging  # local: only needed for the over-long-input warning

    # `sess` and `model` are module-level globals; they are only read here,
    # so no `global` declaration is needed.
    model.batch_size = 1  # We decode one sentence at a time.

    # Load vocabularies.
    qt_vocab_path = os.path.join(FLAGS.data_dir,
                                 "vocab%d.qt" % FLAGS.qt_vocab_size)
    ans_vocab_path = os.path.join(FLAGS.data_dir,
                                  "vocab%d.ans" % FLAGS.ans_vocab_size)
    qt_vocab, _ = data_utils.initialize_vocabulary(qt_vocab_path)
    _, rev_ans_vocab = data_utils.initialize_vocabulary(ans_vocab_path)

    # Get token-ids for the input sentence.
    token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence),
                                                 qt_vocab)
    # Smallest bucket whose encoder size exceeds the input length.
    fitting = [b for b in xrange(len(_buckets)) if _buckets[b][0] > len(token_ids)]
    if fitting:
        bucket_id = min(fitting)
    else:
        # NOTE(review): assumes model.get_batch tolerates inputs longer than
        # the bucket (truncating them) — confirm against the model.
        bucket_id = len(_buckets) - 1
        logging.warning("Sentence longer than the largest bucket: %s", sentence)
    # Get a 1-element batch to feed the sentence to the model.
    encoder_inputs, decoder_inputs, target_weights = model.get_batch(
        {bucket_id: [(token_ids, [])]}, bucket_id)
    # Get output logits for the sentence.
    _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                     target_weights, bucket_id, True)
    # This is a greedy decoder - outputs are just argmaxes of output_logits.
    outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
    # If there is an EOS symbol in outputs, cut them at that point.
    if data_utils.EOS_ID in outputs:
        outputs = outputs[:outputs.index(data_utils.EOS_ID)]
    # Build the answer sentence corresponding to outputs.
    result = " ".join(
        [tf.compat.as_str(rev_ans_vocab[output]) for output in outputs])
    print("Server sent data:%s" % result)
    return result