예제 #1
0
def train_process(trainer, predictor=None):
    """Build the train and eval graphs for `trainer` and run melt's train flow.

    Args:
        trainer: model object handed to `gen_train` and `gen_validate`.
        predictor: optional model object. When it is given and
            `FLAGS.gen_predict` is set, a predict graph is generated for it
            inside the main variable scope; it is also passed to
            `gen_validate` and to the metric evaluation callback.

    `sess` is not a local name; it is resolved from the enclosing module.
    """
    app = InputApp.InputApp()
    inputs = app.gen_input()

    with tf.variable_scope(FLAGS.main_scope) as main_scope:
        ops, gen_feed_dict, deal_results = gen_train(app, inputs, trainer)

        # Everything built after this point shares the variables created
        # while building the training graph.
        main_scope.reuse_variables()

        # Saving the predict graph lets later prediction run directly
        # without rebuilding; gen_validate may also rely on it when direct
        # prediction is used as the per-epoch evaluation.
        want_predict_graph = predictor is not None and FLAGS.gen_predict
        if want_predict_graph:
            gen_predict_graph(predictor)

        eval_ops, gen_eval_feed_dict, deal_eval_results = gen_validate(
            app, inputs, trainer, predictor)

        # Generative algos could be scored this way too, but it is slow,
        # so metric evaluation is skipped for them.
        metric_eval_fn = None
        if FLAGS.metric_eval and not algos_factory.is_generative(FLAGS.algo):

            def _evaluate_scores():
                return evaluator.evaluate_scores(predictor, random=True)

            metric_eval_fn = _evaluate_scores

    init_fn = None
    summary_excls = None
    if not FLAGS.pre_calc_image_feature:
        init_fn = melt.image.create_image_model_init_fn(
            FLAGS.image_model_name, FLAGS.image_checkpoint_file)
        if want_predict_graph:
            # Exclude the image model's summaries: with a predict graph
            # present, its ops might require the image feature feed.
            summary_excls = [FLAGS.image_model_name]

    melt.print_global_varaiables()

    # NOTE: if melt.constant is used in the predictor, sess must be passed.
    train_flow_kwargs = {
        'gen_feed_dict_fn': gen_feed_dict,
        'deal_results_fn': deal_results,
        'eval_ops': eval_ops,
        'gen_eval_feed_dict_fn': gen_eval_feed_dict,
        'deal_eval_results_fn': deal_eval_results,
        'optimizer': FLAGS.optimizer,
        'learning_rate': FLAGS.learning_rate,
        'num_steps_per_epoch': app.num_steps_per_epoch,
        'model_dir': FLAGS.model_dir,
        'metric_eval_fn': metric_eval_fn,
        'summary_excls': summary_excls,
        'init_fn': init_fn,
        'sess': sess,
    }
    melt.apps.train_flow(ops, **train_flow_kwargs)
예제 #2
0
def train():
    """Create the trainer/predictor for `FLAGS.algo` and run melt's train flow.

    The training graph is built first; the variable scope is then switched
    to reuse mode and the trainer to eval mode before the optional predict
    graph and the validation graph are built.

    `sess` and `global_scope` are not local names; they are resolved from
    the enclosing module.
    """
    app = InputApp.InputApp()
    inputs = app.gen_input()

    with tf.variable_scope(FLAGS.main_scope) as main_scope:
        models = algos_factory.gen_trainer_and_predictor(FLAGS.algo)
        trainer, predictor = models
        logging.info('trainer:{}'.format(trainer))
        logging.info('predictor:{}'.format(predictor))

        ops, gen_feed_dict, deal_results = gen_train(app, inputs, trainer)

        # From here on, graph construction reuses the training variables
        # and the trainer is in eval mode.
        main_scope.reuse_variables()
        algos_factory.set_eval_mode(trainer)

        if predictor is not None and FLAGS.gen_predict:
            # The returned tensors are not used further in this function.
            _beam_text, _beam_text_score = gen_predict_graph(predictor)

        eval_ops, gen_eval_feed_dict, deal_eval_results = gen_validate(
            app, inputs, trainer, predictor)

        # Metric evaluation runs for non-generative algos, or whenever an
        # assistant model dir is configured.
        metric_eval_fn = None
        if FLAGS.metric_eval and (
                not algos_factory.is_generative(FLAGS.algo)
                or FLAGS.assistant_model_dir):

            def _evaluate_scores():
                return evaluator.evaluate_scores(predictor, random=True)

            metric_eval_fn = _evaluate_scores

    melt.print_global_varaiables()

    # NOTE: if melt.constant is used in the predictor, sess must be passed.
    train_flow_kwargs = {
        'gen_feed_dict_fn': gen_feed_dict,
        'deal_results_fn': deal_results,
        'eval_ops': eval_ops,
        'gen_eval_feed_dict_fn': gen_eval_feed_dict,
        'deal_eval_results_fn': deal_eval_results,
        'optimizer': FLAGS.optimizer,
        'learning_rate': FLAGS.learning_rate,
        'num_steps_per_epoch': app.num_steps_per_epoch,
        'model_dir': FLAGS.model_dir,
        'metric_eval_fn': metric_eval_fn,
        # Only the global scope is restored, since the evaluator might put
        # another predictor in the graph under a different scope name
        # (e.g. dual_bow).
        'restore_scope': global_scope,
        'sess': sess,
    }
    melt.apps.train_flow(ops, **train_flow_kwargs)
예제 #3
0
파일: train.py 프로젝트: buptpriswang/hasky
def train():
    """Create the trainer/predictor for `FLAGS.algo` and run melt's train flow.

    Builds the training graph, switches the variable scope to reuse mode,
    optionally generates a predict graph, builds the validation graph, and
    passes everything to `melt.apps.train_flow`.

    `sess` is not a local name; it is resolved from the enclosing module.
    """
    app = InputApp.InputApp()
    inputs = app.gen_input()

    with tf.variable_scope(FLAGS.main_scope) as main_scope:
        models = algos_factory.gen_trainer_and_predictor(FLAGS.algo)
        trainer, predictor = models
        logging.info('trainer:{}'.format(trainer))
        logging.info('predictor:{}'.format(predictor))

        # NOTE(review): eval mode is set here, before the training graph is
        # built, and again after it - confirm the first call is intended.
        algos_factory.set_eval_mode(trainer)
        ops, gen_feed_dict, deal_results = gen_train(app, inputs, trainer)

        # From here on, graph construction reuses the training variables.
        main_scope.reuse_variables()
        algos_factory.set_eval_mode(trainer)

        if predictor is not None and FLAGS.gen_predict:
            gen_predict_graph(predictor)

        eval_ops, gen_eval_feed_dict, deal_eval_results = gen_validate(
            app, inputs, trainer, predictor)

        # Generative algos could be scored this way too, but it is slow,
        # so metric evaluation is skipped for them.
        metric_eval_fn = None
        if FLAGS.metric_eval and not algos_factory.is_generative(FLAGS.algo):

            def _evaluate_scores():
                return evaluator.evaluate_scores(predictor, random=True)

            metric_eval_fn = _evaluate_scores

    melt.print_global_varaiables()

    # NOTE: if melt.constant is used in the predictor, sess must be passed.
    train_flow_kwargs = {
        'gen_feed_dict_fn': gen_feed_dict,
        'deal_results_fn': deal_results,
        'eval_ops': eval_ops,
        'gen_eval_feed_dict_fn': gen_eval_feed_dict,
        'deal_eval_results_fn': deal_eval_results,
        'optimizer': FLAGS.optimizer,
        'learning_rate': FLAGS.learning_rate,
        'num_steps_per_epoch': app.num_steps_per_epoch,
        'model_dir': FLAGS.model_dir,
        'metric_eval_fn': metric_eval_fn,
        'sess': sess,
    }
    melt.apps.train_flow(ops, **train_flow_kwargs)
예제 #4
0
def train_process(trainer, predictor=None):
  """Build the train/eval graphs, then train or run a small prediction demo.

  When `FLAGS.mode == 'train'` the graphs are handed to
  `melt.apps.train_flow`. For any other mode the predictor is loaded from
  `FLAGS.model_dir` and beam-search results for a few sample texts are
  printed.

  Args:
    trainer: model object handed to `gen_train` and `gen_validate`.
    predictor: optional model object. It is required outside train mode.

  Raises:
    ValueError: outside train mode, when `predictor` is None or the predict
      graph was not generated because `FLAGS.gen_predict` is unset.

  `sess` is not a local name; it is resolved from the enclosing module.
  """
  input_app = InputApp.InputApp()
  input_results = input_app.gen_input()

  # These stay None unless the predict graph is generated below, so the
  # predict branch can detect a missing graph instead of hitting NameError.
  beam_text = beam_text_score = None

  with tf.variable_scope(FLAGS.main_scope) as scope:
    ops, gen_feed_dict, deal_results = gen_train(
      input_app,
      input_results,
      trainer)
    # Everything built after this point reuses the training variables.
    scope.reuse_variables()

    if predictor is not None and FLAGS.gen_predict:
      beam_text, beam_text_score = gen_predict_graph(predictor)

    eval_ops, gen_eval_feed_dict, deal_eval_results = gen_validate(
      input_app,
      input_results,
      trainer,
      predictor)

    # Generative algos could be scored this way too, but it is slow, so
    # metric evaluation is skipped for them.
    metric_eval_fn = None
    if FLAGS.metric_eval and not algos_factory.is_generative(FLAGS.algo):
      metric_eval_fn = lambda: evaluator.evaluate_scores(predictor, random=True)

  if FLAGS.mode == 'train':
    melt.print_global_varaiables()
    # NOTE: if melt.constant is used in the predictor, sess must be passed.
    melt.apps.train_flow(ops,
                         gen_feed_dict_fn=gen_feed_dict,
                         deal_results_fn=deal_results,
                         eval_ops=eval_ops,
                         gen_eval_feed_dict_fn=gen_eval_feed_dict,
                         deal_eval_results_fn=deal_eval_results,
                         optimizer=FLAGS.optimizer,
                         learning_rate=FLAGS.learning_rate,
                         num_steps_per_epoch=input_app.num_steps_per_epoch,
                         model_dir=FLAGS.model_dir,
                         metric_eval_fn=metric_eval_fn,
                         sess=sess)
    return

  # Test-predict mode: it needs both a predictor and its beam-search graph.
  if predictor is None or beam_text is None:
    raise ValueError(
      'predict mode requires a predictor and FLAGS.gen_predict to be set')
  _run_predict_demo(predictor, beam_text, beam_text_score)


def _run_predict_demo(predictor, beam_text, beam_text_score):
  """Load the model from FLAGS.model_dir and print predictions for sample texts."""
  predictor.load(FLAGS.model_dir)
  # Imported locally, matching where the original code performed the import.
  from conf import INPUT_TEXT_MAX_WORDS

  print('-------------------------', tf.get_collection('scores'))

  # TODO: copied from prepare/gen-records.py
  def _text2ids(text, max_words):
    """Convert text to word ids, truncated and then zero-padded to max_words."""
    word_ids = text2ids.text2ids(text,
                                 seg_method=FLAGS.seg_method,
                                 feed_single=FLAGS.feed_single,
                                 allow_all_zero=True,
                                 pad=False)
    return gezi.pad(word_ids[:max_words], max_words, 0)

  input_texts = [
    '大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施',
  ]

  for input_text in input_texts:
    word_ids = _text2ids(input_text, INPUT_TEXT_MAX_WORDS)
    print('word_ids', word_ids, 'len:', len(word_ids))
    print(text2ids.ids2text(word_ids))

    # TODO/FIXME: similar to inference.py; only OK for no-attention mode.
    # The input tensor is addressed by its hard-coded graph name here.
    texts, scores = sess.run(
      [tf.get_collection('text')[0], tf.get_collection('text_score')[0]],
      feed_dict={'seq2seq/model_init_1/input_text:0': [word_ids]})
    print(texts[0], text2ids.ids2text(texts[0]), scores[0])

    # Same input through the beam-search tensors of the predict graph.
    texts, scores = sess.run([beam_text, beam_text_score],
                             feed_dict={predictor.input_text_feed: [word_ids]})
    for text, score in zip(texts[0], scores[0]):
      print(text, text2ids.ids2text(text), score)

  # Batched run: several inputs in a single session call.
  input_texts = [
    '大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施',
    "宝宝太胖怎么办呢",
    '大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施',
  ]
  word_ids_list = [_text2ids(input_text, INPUT_TEXT_MAX_WORDS)
                   for input_text in input_texts]

  timer = gezi.Timer()
  texts_list, scores_list = sess.run(
    [beam_text, beam_text_score],
    feed_dict={predictor.input_text_feed: word_ids_list})
  # Read the timer right after the run so the reported figure excludes the
  # time spent printing the results below.
  elapsed_ms = timer.elapsed_ms()

  for texts, scores in zip(texts_list, scores_list):
    for text, score in zip(texts, scores):
      print(text, text2ids.ids2text(text), score, math.log(score))

  print('beam_search using time(ms):', elapsed_ms)