示例#1
0
def main(_):
  """Load text-id resources, build a module-level Predictor, dump graph
  collections for inspection, then hand off to run().

  The predictor is published as a module global so run() (defined elsewhere
  in this file) can use it.
  """
  text2ids.init()
  global predictor
  predictor = melt.Predictor(FLAGS.model_dir)
  # Debug output: all collection keys plus the 'score' collection contents.
  print(tf.get_default_graph().get_all_collection_keys())
  print(tf.get_collection('score'))
  run()
示例#2
0
def main(_):
  """Load a scoring model and run two hard-coded (title, query) predictions.

  The commented-out calls below are earlier ad-hoc test cases kept for
  reference; their strings are GBK-mojibake from the original source and are
  preserved verbatim.
  """
  text2ids.init()
  predictor = melt.Predictor(FLAGS.model_dir)

  # Debug output: contents of the 'score' collection of the restored graph.
  print(tf.get_collection('score'))
  #predict(predictor, '���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���', '��˿�ڿ�Ů')
  #predict(predictor, '���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���', '�Ը�����')
  #predict(predictor, '���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���', '�Ը�Ů�ڿ�')
  #predict(predictor, '���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���', 'ƻ������')
  #predict(predictor, '���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���', '�Ը�͸���ڿ�')
  #predict(predictor, '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ', '�߲�')
  #predict(predictor, '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ', '����')
  #predict(predictor, '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ', '������ֲ')
  #predict(predictor, '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ', 'С����ͼƬ')
  #predict(predictor, '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ', '����')
  #predict(predictor, '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ', 'С����')
  #predict(predictor, '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ', '��������')
  #predict(predictor, '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ', '����С����')
  #predict(predictor, '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ', '������ʵ')
  #predict(predictor, '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ', 'С����')
  #predict(predictor, '�ڶ����������� ֱ�������ν��3��Ƭ��', '���ν��3')
  #predict(predictor, "ѧ���ٵ�����ʦ�� �ȶ��⾾ͷ����ͷ��ǽײ��3��סԺ", "Ů��")
  #predict(predictor, "ѧ���ٵ�����ʦ�� �ȶ��⾾ͷ����ͷ��ǽײ��3��סԺ", "ѧ��")
  #predict(predictor, "ѧ���ٵ�����ʦ�� �ȶ��⾾ͷ����ͷ��ǽײ��3��סԺ", "Ů��ѧ��")
  #predict(predictor, "ѧ���ٵ�����ʦ�� �ȶ��⾾ͷ����ͷ��ǽײ��3��סԺ", "Ů��ѧ��")

  #predict(predictor, '�ձ�ֱ��COSME���;���CANMAKE�۲���׸���Ӱ�׸�', 'canmake/���� ��Ӱ��')
  #predict(predictor, '16��HR�������������������۴�˪15ML��������˪�ֻ�', 'HR/������ �������������۴�˪')
  #predict(predictor, 'pony�Ƽ���������С�����ݰ�play101P���ݱʲ�ɫ˫ͷ�߹���Ӱ��', 'ETUDE HOUSE/����֮�� ˫ͷ�߹����ݰ�')
  
  # Two live (title, query) scoring calls.
  predict(predictor, '�Ÿ�����ϴ���̻�ױƷ�ײ辻������ϴ��˪����ϴ�������', 'AGLAIA/�Ÿ����� �ײ辻������ϴ��˪')
  predict(predictor, '�����¹�����ֱ��ά�ٵ�/Weleda�β��ƽ�Ᵽʪ������˪30ml�л�����', 'Weleda/ά�ٵ� �β��ˮ��˪')
def main(_):
    """Initialize global training resources (vocab, text ids, image util) and
    start training in a module-level session."""
    #-----------init global resource
    melt.apps.train.init()

    # Default vocab path: sibling 'vocab.txt' next to the model directory.
    FLAGS.vocab = FLAGS.vocab or os.path.join(os.path.dirname(FLAGS.model_dir),
                                              'vocab.txt')

    image_util.init()

    vocabulary.init()
    text2ids.init()

    ## TODO FIXME if evaluator init before main graph(assistant predictor with image model) then will wrong for finetune later,
    ## image scope as not defined, to set reuse = None? though assistant in different scope graph still scope reused??
    ## so right now just let evaluator lazy init, init when used after main graph build

    # try:
    #   evaluator.init()
    # except Exception:
    #   print(traceback.format_exc(), file=sys.stderr)
    #   print('evaluator init fail will not do metric eval')
    #   FLAGS.metric_eval = False

    logging.info('algo:{}'.format(FLAGS.algo))
    logging.info('monitor_level:{}'.format(FLAGS.monitor_level))

    # Shared session published as a module global for train() to use.
    global sess
    sess = melt.get_session(log_device_placement=FLAGS.log_device_placement)

    train()
示例#4
0
def main(_):
  """Load the model and run predict() on three hard-coded texts.

  The literals are GBK-mojibake from the original source; preserved verbatim.
  """
  text2ids.init()
  predictor = melt.Predictor(FLAGS.model_dir)
  
  predict(predictor, "����̫����ô����")
  predict(predictor, "���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���")
  predict(predictor, "����������ʵ��С��ô��,����������ʵ��С���δ�ʩ")
示例#5
0
def main(_):
  """Build a beam-search predict graph for FLAGS.algo, restore weights, run
  one sample prediction, then enter an interactive predict loop.

  NOTE(review): uses `raw_input`, so this snippet is Python 2 only.
  """
  text2ids.init()

  # Optional wrapping variable scope: explicit name, else the algo name.
  global_scope = ''
  if FLAGS.add_global_scope:
    global_scope = FLAGS.global_scope if FLAGS.global_scope else FLAGS.algo
 
  with tf.variable_scope(global_scope):
    predictor =  algos_factory.gen_predictor(FLAGS.algo)
    with tf.variable_scope(FLAGS.main_scope) as scope:
      beam_text, beam_score = predictor.init_predict_text(decode_method=SeqDecodeMethod.beam_search, 
                                                          beam_size=FLAGS.beam_size,
                                                          convert_unk=False)  

  predictor.load(FLAGS.model_dir) 

  #predict(predictor, "ʪĤ��ʪ��")
  #predict(predictor, "�����Ȱͺ��ű��ġ����������������ۣ������²ۣ�Ҫ������Ϊ������λ�����һ����")
  #predict(predictor, "��Ů١���")
  #predict(predictor, "����̫����ô����")
  #predict(predictor, "���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���")
  #predict(predictor, "��++��������ʵ��С��ô��,����������ʵ��С���δ�ʩ")

  predict(predictor, '2015�����й���������ǩ�������')

  # Interactive loop: read a line, predict; 'quit...' exits.
  while True:
    text = raw_input('')
    if text.startswith('quit'):
      break
    predict(predictor, text)
示例#6
0
def main(_):
  """Score a set of hard-coded (title, query) pairs one at a time, then run a
  batched predicts() call over two pairs.

  The repeated title with varying queries probes how the model ranks related
  vs. unrelated queries against the same document title.
  """
  text2ids.init()
  predictor = melt.Predictor(FLAGS.model_dir)
  #predict(predictor, '包邮买二送一性感女内裤低腰诱惑透视蕾丝露臀大蝴蝶三角内裤女夏-淘宝网', '蕾丝内裤女')
  #predict(predictor, '包邮买二送一性感女内裤低腰诱惑透视蕾丝露臀大蝴蝶三角内裤女夏-淘宝网', '性感内衣')
  #predict(predictor, '包邮买二送一性感女内裤低腰诱惑透视蕾丝露臀大蝴蝶三角内裤女夏-淘宝网', '性感女内裤')
  #predict(predictor, '包邮买二送一性感女内裤低腰诱惑透视蕾丝露臀大蝴蝶三角内裤女夏-淘宝网', '苹果电脑')
  #predict(predictor, '包邮买二送一性感女内裤低腰诱惑透视蕾丝露臀大蝴蝶三角内裤女夏-淘宝网', '性感透明内裤')
  predict(predictor, '大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施', '蔬菜')
  predict(predictor, '大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施', '橘子')
  predict(predictor, '大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施', '辣椒种植')
  predict(predictor, '大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施', '小辣椒图片')
  predict(predictor, '大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施', '辣椒')
  predict(predictor, '大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施', '小辣椒')
  predict(predictor, '大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施', '辣椒辣椒')
  predict(predictor, '大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施', '辣椒小辣椒')
  predict(predictor, '大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施', '辣椒果实')
  predict(predictor, '大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施', '小橘子')
  predict(predictor, '众多名车齐上阵 直击《变形金刚3》片场', '变形金刚3')
  #predict(predictor, "学生迟到遭老师打 扇耳光揪头发把头往墙撞致3人住院", "女孩")
  #predict(predictor, "学生迟到遭老师打 扇耳光揪头发把头往墙撞致3人住院", "学生")
  #predict(predictor, "学生迟到遭老师打 扇耳光揪头发把头往墙撞致3人住院", "女生学生")
  #predict(predictor, "学生迟到遭老师打 扇耳光揪头发把头往墙撞致3人住院", "女生学术")

  # Batched API: one call scoring two (title, query) pairs positionally.
  predicts(predictor, 
           ['包邮买二送一性感女内裤低腰诱惑透视蕾丝露臀大蝴蝶三角内裤女夏-淘宝网', '大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施'],
           ['蕾丝内裤女', '辣椒'])
示例#7
0
def init(image_model_=None):
  """Lazily initialize evaluator globals: distinct-text arrays, vocab, and the
  optional assistant predictor.

  Idempotent: guarded by the module-level `inited` flag. Mutates module
  globals (all_distinct_texts, all_distinct_text_strs, vocab, vocab_size,
  assistant_predictor, image_model).

  Args:
    image_model_: optional pre-built image model to reuse instead of loading
      one from FLAGS.image_checkpoint_file.
  """
  global inited

  if inited:
    return

  test_dir = FLAGS.valid_resource_dir
  global all_distinct_texts, all_distinct_text_strs
  global vocab, vocab_size
  if all_distinct_texts is None:
    print('loading valid resource from:', test_dir)
    #vocabulary.init()
    text2ids.init()
    vocab = vocabulary.vocab
    vocab_size = vocabulary.vocab_size
    
    # Cached distinct texts; fall back to empty list when the npy is absent.
    if os.path.exists(test_dir + '/distinct_texts.npy'):
      all_distinct_texts = np.load(test_dir + '/distinct_texts.npy')
    else:
      all_distinct_texts = []
    
    #to avoid outof gpu mem
    #all_distinct_texts = all_distinct_texts[:FLAGS.max_texts]
    print('all_distinct_texts len:', len(all_distinct_texts), file=sys.stderr)
    
    #--padd it as test data set might be smaller in shape[1]
    all_distinct_texts = np.array([gezi.nppad(text, TEXT_MAX_WORDS) for text in all_distinct_texts])
    if FLAGS.feed_dict:
      all_distinct_texts = texts2ids(evaluator.all_distinct_text_strs)
    if os.path.exists(test_dir + '/distinct_text_strs.npy'):
      all_distinct_text_strs = np.load(test_dir + '/distinct_text_strs.npy')
    else:
      all_distinct_text_strs = []

    init_labels()

  #for evaluation without train will also use evaluator so only set log path in train.py
  #logging.set_logging_path(FLAGS.model_dir)
  if FLAGS.assistant_model_dir:
    global assistant_predictor
    #use another session different from main graph, otherwise variable will be destroy/re initailized in melt.flow
    #by default now Predictor using tf.Session already, here for safe, if use same session then not work
    #NOTICE must use seperate sess!!
    # NOTE(review): `image_features` is not defined in this function — it is
    # presumably a module-level global; confirm before reuse.
    if is_raw_image(image_features) and not melt.varname_in_checkpoint(FLAGS.image_model_name, FLAGS.assistant_model_dir):
      print('assist predictor use deepiu.util.sim_predictor.SimPredictor as is raw image as input')
      global image_model 
      if image_model_ is not None:
        image_model = image_model_ 
      else:
        image_model = melt.image.ImageModel(FLAGS.image_checkpoint_file, 
                                            FLAGS.image_model_name, 
                                            feature_name=None)
      assistant_predictor = deepiu.util.sim_predictor.SimPredictor(FLAGS.assistant_model_dir, image_model=image_model)
      print('assistant predictor init ok')
    else:
      assistant_predictor = melt.SimPredictor(FLAGS.assistant_model_dir)
    print('assistant_predictor', assistant_predictor)

  inited = True
示例#8
0
def main(_):
  """Load two predictors (FLAGS.model_dir plus a hard-coded second model) and
  dump graph collections for comparison/debugging.

  NOTE(review): the second model path is hard-coded to a developer machine;
  this snippet is a scratch/debug script.
  """
  text2ids.init()
  global predictor
  predictor = melt.Predictor(FLAGS.model_dir)
  predictor2 = melt.Predictor('/home/gezi/new/temp/makeup/title2name/model/cnn.hic/')
  print(predictor, predictor2)
  print(tf.get_default_graph().get_all_collection_keys())
  print(tf.get_collection('score'))
示例#9
0
def main(_):
    """Initialize logging, input pipeline, and vocab/text-id resources, then
    run test() under the (optionally named) global variable scope."""
    logging.init(logtostderr=True, logtofile=False)

    InputApp.init()
    vocabulary.init()
    text2ids.init()

    # Scope name: explicit FLAGS.global_scope wins, else the algo name;
    # empty string (no scoping) when add_global_scope is off.
    scope_name = ''
    if FLAGS.add_global_scope:
        scope_name = FLAGS.global_scope or FLAGS.algo

    with tf.variable_scope(scope_name):
        test()
示例#10
0
def main(_):
  """Load the model (with FLAGS.debug) and run predict() on one hard-coded
  text; earlier test inputs are kept commented out (GBK-mojibake, verbatim)."""
  text2ids.init()
  predictor = melt.Predictor(FLAGS.model_dir, debug=FLAGS.debug)
  
  #predict(predictor, "�δﻪ�������»�Ů���� ��ͣ����̫̫(ͼ)")
  #predict(predictor, "������������_��������ǰ��Ա���Ƭ")
  ##predict(predictor, "��Сͨ�Ժ�������������ڼ� ��Ʒ������լ�в�")
  ##predict(predictor, "ѧ���ٵ�����ʦ�� �ȶ��⾾ͷ����ͷ��ǽײ��3��סԺ")
  ##predict(predictor, "����̫����ô����")
  ##predict(predictor, "���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���")
  ##predict(predictor, "����ף�Ŀǰ4����1�굶")
  ##predict(predictor, "����������ʵ��С��ô��,����������ʵ��С���δ�ʩ")
  ##predict(predictor, "����̫����ô����")
  #predict(predictor, "����������ʵ��С��ô��,����������ʵ��С���δ�ʩ")
  #predict(predictor, "����֮��(��):�ڶ���ħ���,�ڶ��������ι��� - �����")
  #predict(predictor, "�ڶ����������� ֱ�������ν��3��Ƭ��")

  predict(predictor, "���������뺫С����������׿���決�����մɷ���jy18")
示例#11
0
def main(_):
  """Load the model (with FLAGS.debug), run one single prediction and one
  batched predicts() call; other inputs kept commented (mojibake, verbatim)."""
  text2ids.init()
  predictor = melt.Predictor(FLAGS.model_dir, debug=FLAGS.debug)
  
  #predict(predictor, "�δﻪ�������»�Ů���� ��ͣ����̫̫(ͼ)")
  #predict(predictor, "������������_��������ǰ��Ա���Ƭ")
  #predict(predictor, "��Сͨ�Ժ�������������ڼ� ��Ʒ������լ�в�")
  #predict(predictor, "ѧ���ٵ�����ʦ�� �ȶ��⾾ͷ����ͷ��ǽײ��3��סԺ")
  #predict(predictor, "����̫����ô����")
  #predict(predictor, "���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���")
  #predict(predictor, "����ף�Ŀǰ4����1�굶")
  #predict(predictor, "����������ʵ��С��ô��,����������ʵ��С���δ�ʩ")
  #predict(predictor, "����̫����ô����")
  predict(predictor, "����������ʵ��С��ô��,����������ʵ��С���δ�ʩ")

  # Batched variant over two texts.
  predicts(predictor, [
    "���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���",
    "����������ʵ��С��ô��,����������ʵ��С���δ�ʩ",
  ])
示例#12
0
def main(_):
  """Build a beam-search predict graph for FLAGS.algo, restore weights, and
  run predict() over three hard-coded texts (mojibake, verbatim)."""
  text2ids.init()

  # Optional wrapping variable scope: explicit name, else the algo name.
  global_scope = ''
  if FLAGS.add_global_scope:
    global_scope = FLAGS.global_scope if FLAGS.global_scope else FLAGS.algo
 
  with tf.variable_scope(global_scope):
    predictor =  algos_factory.gen_predictor(FLAGS.algo)
    with tf.variable_scope(FLAGS.main_scope) as scope:
      beam_text, beam_score = predictor.init_predict_text(decode_method=SeqDecodeMethod.beam_search, 
                                                          beam_size=FLAGS.beam_size,
                                                          convert_unk=False)  

  predictor.load(FLAGS.model_dir) 

  predict(predictor, "����̫����ô����")
  predict(predictor, "���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���")
  predict(predictor, "����������ʵ��С��ô��,����������ʵ��С���δ�ʩ")
示例#13
0
def main(_):
    """Set up global training resources (logging path, vocab, text ids) and
    launch train() inside an optional variable scope with a shared session."""
    # ----------- init global resources
    logging.set_logging_path(gezi.get_dir(FLAGS.model_dir))

    vocabulary.init()
    text2ids.init()

    #evaluator.init()

    logging.info('algo:{}'.format(FLAGS.algo))
    logging.info('monitor_level:{}'.format(FLAGS.monitor_level))

    # Scope name: explicit FLAGS.global_scope wins, else the algo name.
    scope_name = ''
    if FLAGS.add_global_scope:
        scope_name = FLAGS.global_scope or FLAGS.algo

    # Session is published as a module global for train() to use.
    global sess
    sess = melt.get_session(log_device_placement=FLAGS.log_device_placement)
    with tf.variable_scope(scope_name):
        train()
示例#14
0
def main(_):
    """Initialize training resources — optionally an image-processing model —
    then run train() under the app-provided global scope."""
    #-----------init global resource
    logging.set_logging_path(gezi.get_dir(FLAGS.model_dir))
    melt.apps.train.init()

    # Only init image processing when a checkpoint file actually exists.
    has_image_model = FLAGS.image_checkpoint_file and os.path.exists(
        FLAGS.image_checkpoint_file)
    if has_image_model:
        print('image_endpoint_feature_name:',
              FLAGS.image_endpoint_feature_name)
        melt.apps.image_processing.init(
            FLAGS.image_model_name,
            feature_name=FLAGS.image_endpoint_feature_name)

    # Without an image model we must rely on pre-computed image features.
    FLAGS.pre_calc_image_feature = FLAGS.pre_calc_image_feature or (
        not has_image_model)

    vocabulary.init()
    text2ids.init()

    ## TODO FIXME if evaluator init before main graph(assistant preidictor with image model) then will wrong for finetune later,
    ## image scope as not defined, to set reuse = None? though assistant in different scope graph still scope reused??
    ## so right now just let evaluator lazy init, init when used after main graph build

    # try:
    #   evaluator.init()
    # except Exception:
    #   print(traceback.format_exc(), file=sys.stderr)
    #   print('evaluator init fail will not do metric eval')
    #   FLAGS.metric_eval = False

    logging.info('algo:{}'.format(FLAGS.algo))
    logging.info('monitor_level:{}'.format(FLAGS.monitor_level))

    global sess
    sess = melt.get_session(log_device_placement=FLAGS.log_device_placement)

    global_scope = melt.apps.train.get_global_scope()
    with tf.variable_scope(global_scope):
        train()
示例#15
0
def main(_):
  """Build greedy + beam-search predict graphs, restore weights, and compare
  both decoders on a few hard-coded inputs (mojibake, verbatim) with timing."""
  text2ids.init()
  global_scope = ''
  if FLAGS.add_global_scope:
    global_scope = FLAGS.global_scope if FLAGS.global_scope else FLAGS.algo
 
  # Session is a module global; gen_predict_graph (defined elsewhere) returns
  # greedy (text, score) and beam (beam_text, beam_score) endpoints.
  global sess
  sess = melt.get_session(log_device_placement=FLAGS.log_device_placement)
  with tf.variable_scope(global_scope):
    predictor =  algos_factory.gen_predictor(FLAGS.algo)
    with tf.variable_scope(FLAGS.main_scope) as scope:
      text, score, beam_text, beam_score = gen_predict_graph(predictor, scope)

  predictor.load(FLAGS.model_dir) 
  #input_text = "������������_��������ǰ��Ա���Ƭ"
  input_texts = ['���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���',
                 '����̫����ô����',
                 '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 '����ף�Ŀǰ4����1�굶']

  for input_text in input_texts:
    word_ids = _text2ids(input_text, INPUT_TEXT_MAX_WORDS)

    print(word_ids)
    print(text2ids.ids2text(word_ids))

    # Greedy decode with timing.
    timer = gezi.Timer()
    text_, score_ = sess.run([text, score], {predictor.input_text_place : [word_ids]})
    print(text_[0], text2ids.ids2text(text_[0]), score_[0], 'time(ms):', timer.elapsed_ms())

    # Beam-search decode with timing; results are per-beam for batch item 0.
    timer = gezi.Timer()
    texts, scores = sess.run([beam_text, beam_score], {predictor.input_text_place : [word_ids]})

    texts = texts[0]
    scores = scores[0]
    for text_, score_ in zip(texts, scores):
      print(text_, text2ids.ids2text(text_), score_)

    print('beam_search using time(ms):', timer.elapsed_ms())
示例#16
0
def main(_):
    """Initialize logging, optional image processing, input/vocab/evaluator
    resources, then run train() inside the configured variable scope."""
    # ----------- init global resources
    logging.set_logging_path(gezi.get_dir(FLAGS.model_dir))

    # Image processing is only needed when features are computed on the fly.
    if not FLAGS.pre_calc_image_feature:
        melt.apps.image_processing.init()

    InputApp.init()
    vocabulary.init()
    text2ids.init()
    evaluator.init()

    logging.info('algo:{}'.format(FLAGS.algo))
    logging.info('monitor_level:{}'.format(FLAGS.monitor_level))

    # Session is published as a module global for train() to use.
    global sess
    sess = melt.get_session(log_device_placement=FLAGS.log_device_placement)

    # Scope name: explicit FLAGS.global_scope wins, else the algo name.
    scope_name = ''
    if FLAGS.add_global_scope:
        scope_name = FLAGS.global_scope or FLAGS.algo
    with tf.variable_scope(scope_name):
        train()
示例#17
0
def main(_):
  """Build exact-score and beam-search predict graphs for FLAGS.algo, restore
  weights, then score and beam-decode each hard-coded input with timing.

  Fix: the exact-score sess.run result is bound to a fresh name (score_val)
  instead of rebinding `score`. The original rebound the graph tensor to an
  ndarray on the first loop iteration, so a second input would have crashed
  sess.run (it only worked because input_texts had one element).
  """
  text2ids.init()
  global_scope = ''
  if FLAGS.add_global_scope:
    global_scope = FLAGS.global_scope if FLAGS.global_scope else FLAGS.algo
 
  sess = melt.get_session(log_device_placement=FLAGS.log_device_placement)
  with tf.variable_scope(global_scope):
    predictor =  algos_factory.gen_predictor(FLAGS.algo)
    with tf.variable_scope(FLAGS.main_scope) as scope:
      ##--notice if not add below len(tf.get_collection('encode_state') is 1, add below will be 2
      ## even though in init_predict_text(decode_method=SeqDecodeMethod.beam) will call generate_sequence_greedy
      #text, score = predictor.init_predict_text(decode_method=SeqDecodeMethod.greedy, 
      #                                          beam_size=FLAGS.beam_size,
      #                                          convert_unk=False)   
      #scope.reuse_variables()
      #score = predictor.init_predict(exact_loss=True)
      #score = predictor.init_predict(exact_prob=True)
      score = predictor.init_predict()
      scope.reuse_variables()
      beam_text, beam_score = predictor.init_predict_text(decode_method=SeqDecodeMethod.beam, 
                                                          beam_size=FLAGS.beam_size,
                                                          convert_unk=False)  

  predictor.load(FLAGS.model_dir, sess=sess) 

  # Dump restored global variables for debugging.
  for item in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    print(item)
  #input_text = "王凯整容了吗_王凯整容前后对比照片"
  input_texts = [
                 #'包邮买二送一性感女内裤低腰诱惑透视蕾丝露臀大蝴蝶三角内裤女夏-淘宝网',
                 #'大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施',
                 #'宝宝太胖怎么办呢',
                 #'蛋龟缸,目前4虎纹1剃刀',
                 #'大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施',
                 #'2015羊年中国风年会晚会签到板设计',
                 '完美 玛丽艳脱角质霜'
                 ]

  for input_text in input_texts:
    word_ids = _text2ids(input_text, INPUT_TEXT_MAX_WORDS)

    print(word_ids)
    print(text2ids.ids2text(word_ids))

    #timer = gezi.Timer()
    #text_, score_ = sess.run([text, score], {predictor.input_text_feed : [word_ids]})
    #print(text_[0], text2ids.ids2text(text_[0]), score_[0], 'time(ms):', timer.elapsed_ms())

    timer = gezi.Timer()
    #texts, scores, preids, paids, seqlens = sess.run([beam_text, beam_score, 
    #                         tf.get_collection('preids')[-1], 
    #                         tf.get_collection('paids')[-1],
    #                         tf.get_collection('seqlens')[-1]],
    #                                        {predictor.input_text_feed : [word_ids]})

    #print(preids)
    #print(paids)
    #print(seqlens)

    # Exact score for this input; bind to a new name so the `score` tensor
    # stays usable on subsequent iterations.
    score_val = sess.run(score, {predictor.input_text_feed: [word_ids], predictor.text_feed: [word_ids]})
    print(score_val)

    # Beam-search decode; take beams for batch item 0.
    texts, scores = sess.run([beam_text, beam_score],
                                            {predictor.input_text_feed : [word_ids]})

    texts = texts[0]
    scores = scores[0]
    for text_, score_ in zip(texts, scores):
      print(text_, text2ids.ids2text(text_), score_, math.log(score_))

    print('beam_search using time(ms):', timer.elapsed_ms())
示例#18
0
def main(_):
  """Debug script: build a beam-search predict graph, restore it, inspect
  attention keys/values for single and batched inputs, then check that the
  attention-keys matmul is invariant to zero-padding extra rows.

  Input strings are GBK-mojibake from the original source, preserved verbatim.
  """
  text2ids.init()
  global_scope = ''
  if FLAGS.add_global_scope:
    global_scope = FLAGS.global_scope if FLAGS.global_scope else FLAGS.algo
 
  sess = melt.get_session(log_device_placement=FLAGS.log_device_placement)
  with tf.variable_scope(global_scope):
    predictor =  algos_factory.gen_predictor(FLAGS.algo)
    with tf.variable_scope(FLAGS.main_scope) as scope:
      ##--notice if not add below len(tf.get_collection('encode_state') is 1, add below will be 2
      ## even though in init_predict_text(decode_method=SeqDecodeMethod.beam) will call generate_sequence_greedy
      #text, score = predictor.init_predict_text(decode_method=SeqDecodeMethod.greedy, 
      #                                          beam_size=FLAGS.beam_size,
      #                                          convert_unk=False)   
      #scope.reuse_variables()
      beam_text, beam_score = predictor.init_predict_text(decode_method=SeqDecodeMethod.beam, 
                                                          beam_size=FLAGS.beam_size,
                                                          convert_unk=False)  

  predictor.load(FLAGS.model_dir, sess=sess) 

  # Dump restored global variables for debugging.
  for item in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    print(item)
  #input_text = "������������_��������ǰ��Ա���Ƭ"
  input_texts = [
                 #'���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���',
                 #'����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 #'����̫����ô����',
                 #'����ף�Ŀǰ4����1�굶',
                 '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 ]

  for input_text in input_texts:
    word_ids = _text2ids(input_text, INPUT_TEXT_MAX_WORDS)

    print(word_ids)
    print(text2ids.ids2text(word_ids))

    #timer = gezi.Timer()
    #text_, score_ = sess.run([text, score], {predictor.input_text_feed : [word_ids]})
    #print(text_[0], text2ids.ids2text(text_[0]), score_[0], 'time(ms):', timer.elapsed_ms())

    # Single-input beam search, also fetching attention keys/values.
    timer = gezi.Timer()
    print(tf.get_collection('encode_state'), len(tf.get_collection('encode_state')))
    texts, scores,  atkeys, atvals  = sess.run([beam_text, beam_score, 
                                             tf.get_collection('attention_keys')[0],
                                             tf.get_collection('attention_values')[0],
                                             ], 
                                            {predictor.input_text_feed : [word_ids]})

    print(atkeys)
    print(atvals)
    print(np.shape(atkeys), np.shape(atvals))

    texts = texts[0]
    scores = scores[0]
    for text_, score_ in zip(texts, scores):
      print(text_, text2ids.ids2text(text_), score_)

    print('beam_search using time(ms):', timer.elapsed_ms())

  # Batched run over two inputs, additionally fetching the attention-keys
  # weight matrix by its graph name.
  input_texts = [
                 '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 '���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���',
                 #'���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ����', #same length as lajiao sentence 15
                 #"����̫����ô����",
                 #'����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 #'����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 #'�޺콨�ǰ���˹��',
                 ]

  word_ids_list = [_text2ids(input_text, INPUT_TEXT_MAX_WORDS) for input_text in input_texts]
  timer = gezi.Timer()
  texts_list, scores_list, atkeys2, atvals2, weights = sess.run([beam_text, beam_score, 
                                             tf.get_collection('attention_keys')[0],
                                             tf.get_collection('attention_values')[0],
                                             'seq2seq/main/decode/attention_keys/weights:0'
                                             ], 
                             feed_dict={predictor.input_text_feed: word_ids_list})
  
  print(atkeys2)
  print(atvals2[0])
  print(np.shape(atkeys2), np.shape(atvals2))

  print('beam_search using time(ms):', timer.elapsed_ms())

  # Padding-invariance check: multiplying attention values (with and without
  # five appended zero rows) by the fetched weights should agree row-wise.
  weights = tf.constant(weights)
  values = tf.get_collection('attention_values')[0]

  values = tf.squeeze(values[0])
  zeros = tf.zeros_like(values[:5,:])

  values2 = tf.concat([values, zeros], 0)

  values = tf.cast(values, tf.float64)
  values2 = tf.cast(values2, tf.float64)
  weights = tf.cast(weights, tf.float64)
  result1 = tf.matmul(values, weights)
  result2 = tf.matmul(values2, weights)

  # NOTE(review): `word_ids` here is the stale value left over from the first
  # loop above, not the batched input — presumably intentional for this
  # experiment, but confirm.
  result1_, result2_ = sess.run([result1, result2], feed_dict={predictor.input_text_feed: [word_ids]})
  #result2 = sess.run(result, feed_dict={predictor.input_text_feed: word_ids_list})

  print(result1_)
  print(result2_)
示例#19
0
 def __init__(self, model_dir, vocab_path, key=None, index=0, sess=None):
     """Wrap a melt.WordsImportancePredictor and load the vocab for text2ids.

     Args:
       model_dir: checkpoint directory for the words-importance model.
       vocab_path: vocabulary file passed to text2ids.init.
       key, index: forwarded to WordsImportancePredictor.
       sess: accepted but unused here — TODO confirm whether it should be
         forwarded to the predictor.
     """
     self.predictor = melt.WordsImportancePredictor(model_dir,
                                                    key=key,
                                                    index=index)
     text2ids.init(vocab_path)
示例#20
0
def main(_):
  """Load text-id resources and the model, then batch-predict every
  (stripped) line of FLAGS.text_file.

  Fix: the input file is opened via a context manager so the handle is closed
  deterministically; the original leaked the open file object.
  """
  text2ids.init()
  predictor = melt.Predictor(FLAGS.model_dir, debug=FLAGS.debug)

  with open(FLAGS.text_file) as f:
    input_texts = [line.strip() for line in f]
  predicts(predictor, input_texts)
示例#21
0
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys, os,glob
import tensorflow as tf
import melt 


from libword_counter import Vocabulary

from deepiu.util import text2ids

# Scratch script: embed hard-coded titles (GBK-mojibake, verbatim) with a
# word2vec embedding-similarity model. Paths are hard-coded to a dev machine.
dir = '/home/gezi/new/temp/makeup/title2name/tfrecord/seq-basic/'

text2ids.init(os.path.join(dir, 'vocab.txt'))
vocab = text2ids.vocab

# Embedding-similarity helper over the 'w_in' word2vec matrix.
embsim = melt.EmbeddingSim(os.path.join(dir, 'word2vec'), name='w_in')

corpus_pattern = os.path.join('/home/gezi/data/product/makeup/tb/title2name/valid/*')

max_words = 50
#itexts = ['ÑÅÊ«À¼÷ìË®Èó˪', 'ÑÅÊ«À¼÷ìС×ØÆ¿', 'ÑÅÊ«À¼÷ìºìʯÁñ', 'æÃÃÀ¿óÎïȪ²¹Ë®¾«»ª', 'Adidas°¢µÏ´ï˹ÄÐÊ¿ÏãË®ÄÐÊ¿¹ÅÁúµ­ÏãË® ±ùµãÄÐÏã100ml¡¾¾©¶«³¬ÊС¿']

itexts = ['ÑÅÊ«À¼÷ìanrÐÞ»¤¼¡Í¸¾«»ªÂ¶']

# Ids for the left side of the similarity comparison, one row per title.
left_ids = [text2ids.text2ids(x, seg_method='basic', feed_single=True, max_words=max_words) for x in itexts]


# Placeholder for left-side id batches: [batch, max_words].
lids_ = tf.placeholder(dtype=tf.int32, shape=[None, max_words]) 
示例#22
0
def main(_):
  """Load text-id resources and a rerank+exact two-stage predictor (published
  as a module global), then hand off to run()."""
  text2ids.init()
  global predictor
  predictor = melt.RerankSimPredictor(FLAGS.model_dir, FLAGS.exact_model_dir)

  run()
示例#23
0
# Command-line flags: vocab file, model directory (dev-machine defaults), and
# the segmentation method used when converting text to ids.
flags.DEFINE_string(
    'vocab',
    '/home/gezi/new/temp/image-caption/lijiaoshou/tfrecord/seq-basic/vocab.txt',
    'vocabulary file')
flags.DEFINE_string(
    'model_dir',
    '/home/gezi/new/temp/image-caption/lijiaoshou/model/rnn.max.gru.bi/', '')
flags.DEFINE_string('seg_method_', 'basic', '')

import gezi
import melt
from deepiu.util import text2ids

import numpy as np

# Module-level setup: vocab first, then the predictor used by predict() below.
text2ids.init(FLAGS.vocab)

predictor = melt.Predictor(FLAGS.model_dir)


def predict(text):
    """Convert `text` to word ids and print the model's per-word importance
    (argmax encode) for it.

    Uses the module-level `predictor` and FLAGS.seg_method_.
    """
    # NOTE(review): `timer` is created but never read — elapsed time is not
    # printed; confirm whether timing output was intended here.
    timer = gezi.Timer()
    text_ids = text2ids.text2ids(text, FLAGS.seg_method_, feed_single=True)
    print('text_ids', text_ids)

    #seq_len = 50

    #print('words', words)
    # Fetch the 'text_importance' endpoint, feeding the ids batch by tensor name.
    argmax_encode = predictor.inference(
        ['text_importance'], feed_dict={'rnn/main/text:0': [text_ids]})
    print('argmax_encode', argmax_encode[0])
示例#24
0
def main(_):
  """Load text-id resources then read/process records via read_records()
  (defined elsewhere in this file)."""
  text2ids.init()
  read_records()
示例#25
0
# Sanity check: this module's UNK-encoding setting must match text2ids'.
print('ENCODE_UNK', ENCODE_UNK, file=sys.stderr)
assert ENCODE_UNK == text2ids.ENCODE_UNK

# Accumulators for left/right texts and their raw strings.
ltexts = []
ltext_strs = []
rtexts = []
rtext_strs = []

# Shared counters (multiprocessing.Value) for record generation:
#how many records generated
counter = Value('i', 0)
#the max num words of the longest text
max_num_words = Value('i', 0)
#the total words of all text
sum_words = Value('i', 0)

text2ids.init()


def _text2ids(text, max_words):
    """Convert `text` to a list of word ids.

    Ids are truncated to `max_words` and, when FLAGS.pad is set, zero-padded
    to exactly `max_words`. Returns [] when the text yields no ids.

    Fix: the function now returns the computed ids — the original fell off
    the end and returned None for every non-empty text (only the empty case
    returned a list). Also drops the unused `word_ids_length` local.
    """
    word_ids = text2ids.text2ids(text,
                                 seg_method=FLAGS.seg_method,
                                 feed_single=FLAGS.feed_single,
                                 allow_all_zero=True,
                                 pad=False)

    if len(word_ids) == 0:
        return []
    word_ids = word_ids[:max_words]
    if FLAGS.pad:
        word_ids = gezi.pad(word_ids, max_words, 0)
    return word_ids
示例#26
0
文件: predict.py 项目: fword/hasky
def main(_):
  """Build a beam-search predict graph, restore weights, then run per-input
  and batched beam decoding with timing.

  Input strings are GBK-mojibake from the original source, preserved verbatim.
  """
  text2ids.init()
  global_scope = ''
  if FLAGS.add_global_scope:
    global_scope = FLAGS.global_scope if FLAGS.global_scope else FLAGS.algo
 
  sess = melt.get_session(log_device_placement=FLAGS.log_device_placement)
  with tf.variable_scope(global_scope):
    predictor =  algos_factory.gen_predictor(FLAGS.algo)
    with tf.variable_scope(FLAGS.main_scope) as scope:
      ##--notice if not add below len(tf.get_collection('encode_state') is 1, add below will be 2
      ## even though in init_predict_text(decode_method=SeqDecodeMethod.beam) will call generate_sequence_greedy
      #text, score = predictor.init_predict_text(decode_method=SeqDecodeMethod.greedy, 
      #                                          beam_size=FLAGS.beam_size,
      #                                          convert_unk=False)   
      #scope.reuse_variables()
      beam_text, beam_score = predictor.init_predict_text(decode_method=SeqDecodeMethod.beam, 
                                                          beam_size=FLAGS.beam_size,
                                                          convert_unk=False)  

  predictor.load(FLAGS.model_dir, sess=sess) 

  # Dump restored global variables for debugging.
  for item in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    print(item)
  #input_text = "������������_��������ǰ��Ա���Ƭ"
  input_texts = [
                 #'���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���',
                 #'����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 #'����̫����ô����',
                 '����ף�Ŀǰ4����1�굶',
                 '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 ]

  # One beam-search run per input, printing each beam with its log score.
  for input_text in input_texts:
    word_ids = _text2ids(input_text, INPUT_TEXT_MAX_WORDS)

    print(word_ids)
    print(text2ids.ids2text(word_ids))

    #timer = gezi.Timer()
    #text_, score_ = sess.run([text, score], {predictor.input_text_feed : [word_ids]})
    #print(text_[0], text2ids.ids2text(text_[0]), score_[0], 'time(ms):', timer.elapsed_ms())

    timer = gezi.Timer()
    texts, scores = sess.run([beam_text, beam_score], 
                                            {predictor.input_text_feed : [word_ids]})


    texts = texts[0]
    scores = scores[0]
    for text_, score_ in zip(texts, scores):
      print(text_, text2ids.ids2text(text_), score_, math.log(score_))

    print('beam_search using time(ms):', timer.elapsed_ms())

  # Batched beam-search over two inputs in a single sess.run.
  input_texts = [
                 '����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 '���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ���������ڿ�Ů��-�Ա���',
                 #'���������һ�Ը�Ů�ڿ�����ջ�͸����˿¶�δ����', #same length as lajiao sentence 15
                 #"����̫����ô����",
                 #'����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 #'����������ʵ��С��ô��,����������ʵ��С���δ�ʩ',
                 #'�޺콨�ǰ���˹��',
                 ]

  word_ids_list = [_text2ids(input_text, INPUT_TEXT_MAX_WORDS) for input_text in input_texts]
  timer = gezi.Timer()
  texts_list, scores_list = sess.run([beam_text, beam_score], 
                             feed_dict={predictor.input_text_feed: word_ids_list})
  

  for texts, scores in zip(texts_list, scores_list):
    for text, score in zip(texts, scores):
      print(text, text2ids.ids2text(text), score, math.log(score))

  print('beam_search using time(ms):', timer.elapsed_ms())
def main(_):
  """Load text-id resources, publish the predictor (and image_model slot) as
  module globals, then hand off to run()."""
  text2ids.init()
  global predictor, image_model
  predictor = melt.Predictor(FLAGS.model_dir)

  run()
示例#28
0
# Interactive captioning demo setup: paths are hard-coded to a dev machine.
vocab_path = '/home/gezi/new/temp/image-caption/ai-challenger/tfrecord/seq-basic/vocab.txt'
valid_dir = '/home/gezi/new/temp/image-caption/ai-challenger/tfrecord/seq-basic/valid'

image_model_name = 'InceptionResnetV2'

feature_name = melt.image.get_features_name(image_model_name)

# If the caption model was finetuned (image vars inside its checkpoint) the
# predictor needs no separate image checkpoint; otherwise pass one in.
#if finetuned model, just  TextPredictor(model_dir, vocab_path)
if not melt.varname_in_checkpoint(image_model_name, model_dir):
  predictor = TextPredictor(model_dir, vocab_path, image_model_checkpoint_path, image_model_name=image_model_name, feature_name=feature_name)
else:
  predictor = TextPredictor(model_dir, vocab_path)
  
vocab = ids2text.vocab 

text2ids.init(vocab_path)

# Similarity predictor over the same image model (last feature index).
sim_predictor = SimPredictor(sim_model_dir, image_model_checkpoint_path, image_model_name=image_model_name, index=-1)

# Validation resources: distinct caption strings and image->text mapping.
text_strs = np.load(os.path.join(valid_dir, 'distinct_text_strs.npy'))
img2text = np.load(os.path.join(valid_dir, 'img2text.npy')).item()

while True:
  image_file = raw_input('image_file like 6275b5349168ac3fab6a493c509301d023cf39d3.jpg:')
  if image_file == 'q':
    break

  image_path = os.path.join(image_dir, image_file)
  print('image_path:', image_path)

  if not os.path.exists(image_path):