Пример #1
0
def evaluate_score():
  """Restore the predictor selected by FLAGS.algo and score it with the evaluator.

  The evaluator is initialised first because the predictor's input width is
  taken from the evaluator's text matrix.
  """
  evaluator.init()

  # Second dimension of the evaluator's text matrix; presumably
  # (num_texts, max_words) -- TODO confirm.
  max_words = evaluator.all_distinct_texts.shape[1]
  print('text_max_words:', max_words)

  scorer = algos_factory.gen_predictor(FLAGS.algo)
  scorer.init_predict(max_words)
  scorer.load(FLAGS.model_dir)

  evaluator.evaluate_scores(scorer)
Пример #2
0
def evaluate_score():
  """Re-save the model from FLAGS.model_dir with its score op registered.

  Builds the predictor's scoring graph and adds the resulting tensor to the
  'score' graph collection. It then restores the checkpoint and saves the
  session back into the resolved model directory, passing step + 1.
  """
  predictor = algos_factory.gen_predictor(FLAGS.algo)
  tf.add_to_collection('score', predictor.init_predict(TEXT_MAX_WORDS))
  predictor.load(FLAGS.model_dir)

  current_step = melt.get_model_step_from_dir(FLAGS.model_dir)
  # FLAGS.model_dir may name either a directory or a checkpoint path
  # (TODO confirm); only the directory part is used for saving.
  save_dir, _ = melt.get_model_dir_and_path(FLAGS.model_dir)
  print('step', current_step, file=sys.stderr)
  print('model_dir', save_dir)

  melt.save_model(melt.get_session(), save_dir, current_step + 1)
Пример #3
0
def predict(model_path='./model.ckpt-12000', stream=None):
  """Score the candidate texts against each image feature line from *stream*.

  Each non-blank input line is ``image_name<TAB>f1<TAB>f2...``. For every
  score returned for that line, one ``image_name<TAB>score<TAB>text`` line is
  printed, where ``text`` is the matching entry of the module-level
  ``text_list``. Presumably there is one score per row of ``ids_list`` --
  TODO confirm.

  Args:
    model_path: checkpoint to restore. Defaults to the path that used to be
      hard-coded, so existing ``predict()`` callers behave the same.
    stream: iterable of text lines. Defaults to ``sys.stdin``, resolved at
      call time.
  """
  predictor = algos_factory.gen_predictor(algo)
  predictor.init_predict(TEXT_MAX_WORDS)
  predictor.load(model_path)

  if stream is None:
    stream = sys.stdin

  for raw_line in stream:
    line = raw_line.strip()
    if not line:
      # A blank line has no name and no features; feeding a (1, 0) feature
      # array to the model would fail, so skip it.
      continue

    fields = line.split('\t')
    image_name = fields[0]
    # A batch of one: shape (1, num_features).
    image_feature = np.array([[float(x) for x in fields[1:]]])

    scores = predictor.bulk_predict(image_feature, ids_list)[0]

    for i, score in enumerate(scores):
      print('{}\t{}\t{}'.format(image_name, score, text_list[i]))
Пример #4
0
def main(_):
  """Build a beam-search decoder, restore FLAGS.model_dir and decode samples.

  Args:
    _: unused (presumably argv forwarded by the app runner).
  """
  text2ids.init()

  # Optional outer variable scope: FLAGS.global_scope if set, else the algo name.
  if FLAGS.add_global_scope:
    global_scope = FLAGS.global_scope or FLAGS.algo
  else:
    global_scope = ''

  with tf.variable_scope(global_scope):
    predictor = algos_factory.gen_predictor(FLAGS.algo)
    with tf.variable_scope(FLAGS.main_scope):
      # The decode tensors are not used below; building the graph before
      # load() is what matters here.
      _beam_text, _beam_score = predictor.init_predict_text(
          decode_method=SeqDecodeMethod.beam_search,
          beam_size=FLAGS.beam_size,
          convert_unk=False)

  predictor.load(FLAGS.model_dir)

  # NOTE(review): these literals were mis-decoded (U+FFFD replacement
  # characters); the original text cannot be recovered from this file.
  sample_texts = (
      "����̫����ô����",
      "���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���",
      "����������ʵ��С��ô��,����������ʵ��С���δ�ʩ",
  )
  for sample_text in sample_texts:
    predict(predictor, sample_text)
Пример #5
0
def main(_):
  """Restore a seq2seq model and print greedy and beam-search decodes.

  For each sample input, prints its word ids and their text. It then prints
  the greedy decode with its latency, followed by every beam with its score
  and the beam-search latency. Each latency covers only the corresponding
  sess.run call, not the id-to-text conversion or printing.

  Args:
    _: unused (presumably argv forwarded by the app runner).
  """
  text2ids.init()
  global_scope = ''
  if FLAGS.add_global_scope:
    global_scope = FLAGS.global_scope if FLAGS.global_scope else FLAGS.algo

  # Kept module-global as in the original; other helpers in this script may
  # read `sess` (TODO confirm).
  global sess
  sess = melt.get_session(log_device_placement=FLAGS.log_device_placement)
  with tf.variable_scope(global_scope):
    predictor = algos_factory.gen_predictor(FLAGS.algo)
    with tf.variable_scope(FLAGS.main_scope) as scope:
      text, score, beam_text, beam_score = gen_predict_graph(predictor, scope)

  predictor.load(FLAGS.model_dir)

  # NOTE(review): these literals were mis-decoded (U+FFFD replacement
  # characters); the original text cannot be recovered from this file.
  input_texts = ['���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���',
                 '����̫����ô����',
                 '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 '����ף�Ŀǰ4����1�굶']

  for input_text in input_texts:
    word_ids = _text2ids(input_text, INPUT_TEXT_MAX_WORDS)

    print(word_ids)
    print(text2ids.ids2text(word_ids))

    feed = {predictor.input_text_place: [word_ids]}

    # Greedy decode. Read the timer before ids2text()/print so the reported
    # time is the sess.run call only.
    timer = gezi.Timer()
    text_, score_ = sess.run([text, score], feed)
    greedy_ms = timer.elapsed_ms()
    print(text_[0], text2ids.ids2text(text_[0]), score_[0], 'time(ms):', greedy_ms)

    # Beam search: same timing rule, so the beam printing loop is excluded.
    timer = gezi.Timer()
    texts, scores = sess.run([beam_text, beam_score], feed)
    beam_ms = timer.elapsed_ms()

    for text_, score_ in zip(texts[0], scores[0]):
      print(text_, text2ids.ids2text(text_), score_)

    print('beam_search using time(ms):', beam_ms)
Пример #6
0
def main(_):
  """Restore a seq2seq model and print beam-search decodes for sample inputs.

  First decodes each sample text with its own sess.run call, then decodes a
  second sample list as one batch. For every beam it prints the token ids,
  their text, the raw score and its natural log. It then prints the sess.run
  latency in milliseconds, which excludes the printing.

  Args:
    _: unused (presumably argv forwarded by the app runner).
  """
  text2ids.init()
  global_scope = ''
  if FLAGS.add_global_scope:
    global_scope = FLAGS.global_scope if FLAGS.global_scope else FLAGS.algo

  sess = melt.get_session(log_device_placement=FLAGS.log_device_placement)
  with tf.variable_scope(global_scope):
    predictor = algos_factory.gen_predictor(FLAGS.algo)
    with tf.variable_scope(FLAGS.main_scope):
      # Only the beam decoder is built. Per the original author's note, also
      # building a greedy decoder here (followed by scope.reuse_variables())
      # makes tf.get_collection('encode_state') hold 2 entries instead of 1.
      beam_text, beam_score = predictor.init_predict_text(
          decode_method=SeqDecodeMethod.beam,
          beam_size=FLAGS.beam_size,
          convert_unk=False)

  predictor.load(FLAGS.model_dir, sess=sess)

  for item in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    print(item)

  def _log_score(value):
    # math.log(0) raises ValueError. A beam score that underflowed to 0 should
    # not abort the whole dump, so report -inf for it instead.
    if value == 0:
      return float('-inf')
    return math.log(value)

  # NOTE(review): these literals were mis-decoded (U+FFFD replacement
  # characters); the original text cannot be recovered from this file.
  input_texts = [
      '����ף�Ŀǰ4����1�굶',
      '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
  ]

  # One input per sess.run call.
  for input_text in input_texts:
    word_ids = _text2ids(input_text, INPUT_TEXT_MAX_WORDS)

    print(word_ids)
    print(text2ids.ids2text(word_ids))

    timer = gezi.Timer()
    texts, scores = sess.run([beam_text, beam_score],
                             {predictor.input_text_feed: [word_ids]})
    # Read the timer before printing so the latency covers sess.run only.
    elapsed_ms = timer.elapsed_ms()

    for text_, score_ in zip(texts[0], scores[0]):
      print(text_, text2ids.ids2text(text_), score_, _log_score(score_))

    print('beam_search using time(ms):', elapsed_ms)

  input_texts = [
      '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
      '���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���',
  ]

  # All inputs in a single batched sess.run call.
  word_ids_list = [_text2ids(input_text, INPUT_TEXT_MAX_WORDS) for input_text in input_texts]
  timer = gezi.Timer()
  texts_list, scores_list = sess.run([beam_text, beam_score],
                                     feed_dict={predictor.input_text_feed: word_ids_list})
  elapsed_ms = timer.elapsed_ms()

  for texts, scores in zip(texts_list, scores_list):
    for text, score in zip(texts, scores):
      print(text, text2ids.ids2text(text), score, _log_score(score))

  print('beam_search using time(ms):', elapsed_ms)
Пример #7
0
import sys

import numpy as np
from libword_counter import Vocabulary
import gezi

from deepiu.image_caption.algos import algos_factory

# Command line (from the argv reads below):
#   argv[1] = text file (text in the last tab-separated field of each line)
#   argv[2] = model checkpoint to load
#   argv[3] = vocabulary file
WORDS_SEP = ' '
TEXT_MAX_WORDS = 80
NUM_RESERVED_IDS = 1
ENCODE_UNK = 0
IMAGE_FEATURE_LEN = 1000

vocabulary = Vocabulary(sys.argv[3], NUM_RESERVED_IDS)

algo = 'bow'
predictor = algos_factory.gen_predictor(algo)
predictor.init_predict(TEXT_MAX_WORDS)
predictor.load(sys.argv[2])

ids_list = []
with open(sys.argv[1]) as text_file:
    for line in text_file:
        # Only the last tab-separated field holds the text to encode.
        text = line.strip().split('\t')[-1]
        # With ENCODE_UNK == 0, out-of-vocabulary words are dropped rather
        # than mapped to an unknown id.
        ids = [
            vocabulary.id(word) for word in text.split(WORDS_SEP)
            if vocabulary.has(word) or ENCODE_UNK
        ]
        ids = gezi.pad(ids, TEXT_MAX_WORDS)
        ids_list.append(ids)

ids_list = np.array(ids_list)
Пример #8
0
def main(_):
  """Debug driver for the attention tensors produced by beam-search decoding.

  Steps:
    1. Build a beam-search decoding graph under the configured variable
       scopes and restore FLAGS.model_dir.
    2. Decode one input at a time, then a small batch. Each run also fetches
       the first entries of the 'attention_keys' and 'attention_values'
       graph collections.
    3. Compare values @ weights with the same product after the values are
       zero-padded by five extra rows.

  Args:
    _: unused (presumably argv forwarded by the app runner).
  """
  text2ids.init()
  # Optional outer variable scope: FLAGS.global_scope if set, else the algo name.
  global_scope = ''
  if FLAGS.add_global_scope:
    global_scope = FLAGS.global_scope if FLAGS.global_scope else FLAGS.algo
 
  sess = melt.get_session(log_device_placement=FLAGS.log_device_placement)
  with tf.variable_scope(global_scope):
    predictor =  algos_factory.gen_predictor(FLAGS.algo)
    with tf.variable_scope(FLAGS.main_scope) as scope:
      ## Original author's note: without the (commented-out) greedy decoder below,
      ## len(tf.get_collection('encode_state')) is 1; enabling it makes it 2, even though
      ## init_predict_text(decode_method=SeqDecodeMethod.beam) itself calls generate_sequence_greedy.
      #text, score = predictor.init_predict_text(decode_method=SeqDecodeMethod.greedy, 
      #                                          beam_size=FLAGS.beam_size,
      #                                          convert_unk=False)   
      #scope.reuse_variables()
      beam_text, beam_score = predictor.init_predict_text(decode_method=SeqDecodeMethod.beam, 
                                                          beam_size=FLAGS.beam_size,
                                                          convert_unk=False)  

  predictor.load(FLAGS.model_dir, sess=sess) 

  # Dump every global variable so the scope names used below can be checked.
  for item in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    print(item)
  # NOTE(review): the string literals below were mis-decoded (U+FFFD replacement
  # characters); the original text cannot be recovered from this file.
  #input_text = "������������_��������ǰ��Ա���Ƭ"
  input_texts = [
                 #'���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���',
                 #'����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 #'����̫����ô����',
                 #'����ף�Ŀǰ4����1�굶',
                 '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 ]

  # Pass 1: decode each input on its own (batch of one).
  for input_text in input_texts:
    word_ids = _text2ids(input_text, INPUT_TEXT_MAX_WORDS)

    print(word_ids)
    print(text2ids.ids2text(word_ids))

    #timer = gezi.Timer()
    #text_, score_ = sess.run([text, score], {predictor.input_text_feed : [word_ids]})
    #print(text_[0], text2ids.ids2text(text_[0]), score_[0], 'time(ms):', timer.elapsed_ms())

    timer = gezi.Timer()
    # Debug: show how many 'encode_state' entries the graph registered (see the note above).
    print(tf.get_collection('encode_state'), len(tf.get_collection('encode_state')))
    # Fetch the decode results plus the first registered attention keys/values.
    # The [0] indexing raises IndexError if either collection is empty.
    texts, scores,  atkeys, atvals  = sess.run([beam_text, beam_score, 
                                             tf.get_collection('attention_keys')[0],
                                             tf.get_collection('attention_values')[0],
                                             ], 
                                            {predictor.input_text_feed : [word_ids]})

    print(atkeys)
    print(atvals)
    print(np.shape(atkeys), np.shape(atvals))

    texts = texts[0]
    scores = scores[0]
    for text_, score_ in zip(texts, scores):
      print(text_, text2ids.ids2text(text_), score_)

    # NOTE(review): elapsed time is read after all printing above, so it includes
    # the print/ids2text cost, not just sess.run.
    print('beam_search using time(ms):', timer.elapsed_ms())

  # Pass 2: decode all inputs as one batch.
  input_texts = [
                 '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 '���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���',
                 #'���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ����', #same length as lajiao sentence 15
                 #"����̫����ô����",
                 #'����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 #'����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 #'�޺콨�ǰ���˹��',
                 ]

  word_ids_list = [_text2ids(input_text, INPUT_TEXT_MAX_WORDS) for input_text in input_texts]
  timer = gezi.Timer()
  # Also fetch the attention-key projection weights by tensor name.
  # NOTE(review): the name is hard-coded; it presumably assumes global scope
  # 'seq2seq' and FLAGS.main_scope == 'main' -- TODO confirm it matches the flags.
  texts_list, scores_list, atkeys2, atvals2, weights = sess.run([beam_text, beam_score, 
                                             tf.get_collection('attention_keys')[0],
                                             tf.get_collection('attention_values')[0],
                                             'seq2seq/main/decode/attention_keys/weights:0'
                                             ], 
                             feed_dict={predictor.input_text_feed: word_ids_list})
  
  print(atkeys2)
  print(atvals2[0])
  print(np.shape(atkeys2), np.shape(atvals2))

  # NOTE(review): includes the printing above, not just sess.run.
  print('beam_search using time(ms):', timer.elapsed_ms())

  # Experiment: does appending zero rows to the attention values change
  # values @ weights for the original rows? The ops below are added to the
  # default graph after the session already exists.
  weights = tf.constant(weights)
  values = tf.get_collection('attention_values')[0]

  # First batch row with size-1 dims squeezed away; the [:5, :] slice below
  # assumes the result is 2-D, presumably (time, dim) -- TODO confirm.
  values = tf.squeeze(values[0])
  zeros = tf.zeros_like(values[:5,:])

  # The same values followed by 5 zero rows (concatenated along axis 0).
  values2 = tf.concat([values, zeros], 0)

  # Presumably cast to float64 to reduce rounding noise when comparing the two products.
  values = tf.cast(values, tf.float64)
  values2 = tf.cast(values2, tf.float64)
  weights = tf.cast(weights, tf.float64)
  result1 = tf.matmul(values, weights)
  result2 = tf.matmul(values2, weights)

  # NOTE(review): `word_ids` is the variable left over from the pass-1 loop
  # (its last input); this raises NameError if that loop never ran.
  result1_, result2_ = sess.run([result1, result2], feed_dict={predictor.input_text_feed: [word_ids]})
  #result2 = sess.run(result, feed_dict={predictor.input_text_feed: word_ids_list})

  print(result1_)
  print(result2_)