def train_process(trainer, predictor=None):
    """Build the train / predict / validate graphs for `trainer` and run training.

    Args:
      trainer: model object handed to gen_train and gen_validate.
      predictor: optional model object. When given and FLAGS.gen_predict is set,
        its prediction graph is built too. It is also what metric evaluation
        scores.
    """
    app = InputApp.InputApp()
    app_inputs = app.gen_input()
    build_predict = predictor is not None and FLAGS.gen_predict

    with tf.variable_scope(FLAGS.main_scope) as main_scope:
        train_ops, train_feed_fn, train_results_fn = gen_train(app, app_inputs, trainer)
        # Graphs built after this point share the training variables.
        main_scope.reuse_variables()

        # Building the predict graph here means it is saved with the model, so it
        # can later be used for prediction directly, without rebuilding from
        # scratch. gen_validate can also use it for per-epoch evaluation.
        if build_predict:
            gen_predict_graph(predictor)

        eval_ops, eval_feed_fn, eval_results_fn = gen_validate(
            app, app_inputs, trainer, predictor)

    # Generative algos could be metric-evaluated too, but it is slow, so skip them.
    use_metric_eval = FLAGS.metric_eval and not algos_factory.is_generative(FLAGS.algo)
    metric_eval_fn = (
        (lambda: evaluator.evaluate_scores(predictor, random=True))
        if use_metric_eval else None)

    init_fn = None
    summary_excls = None
    if not FLAGS.pre_calc_image_feature:
        # Image features are computed in-graph, so the pretrained image model
        # must be initialised from its checkpoint.
        init_fn = melt.image.create_image_model_init_fn(
            FLAGS.image_model_name, FLAGS.image_checkpoint_file)
        if build_predict:
            # Exclude the image model's summaries: with the predict graph present,
            # its ops may require the predictor's image_feature_feed.
            summary_excls = [FLAGS.image_model_name]

    melt.print_global_varaiables()
    # sess must be passed explicitly if the predictor uses melt.constant.
    melt.apps.train_flow(
        train_ops,
        gen_feed_dict_fn=train_feed_fn,
        deal_results_fn=train_results_fn,
        eval_ops=eval_ops,
        gen_eval_feed_dict_fn=eval_feed_fn,
        deal_eval_results_fn=eval_results_fn,
        optimizer=FLAGS.optimizer,
        learning_rate=FLAGS.learning_rate,
        num_steps_per_epoch=app.num_steps_per_epoch,
        model_dir=FLAGS.model_dir,
        metric_eval_fn=metric_eval_fn,
        summary_excls=summary_excls,
        init_fn=init_fn,
        sess=sess)
def train():
    """Create the trainer/predictor for FLAGS.algo, build all graphs and train.

    NOTE: a later `def train()` in this module shadows this definition.
    """
    app = InputApp.InputApp()
    app_inputs = app.gen_input()

    with tf.variable_scope(FLAGS.main_scope) as main_scope:
        trainer, predictor = algos_factory.gen_trainer_and_predictor(FLAGS.algo)
        logging.info('trainer:{}'.format(trainer))
        logging.info('predictor:{}'.format(predictor))

        train_ops, train_feed_fn, train_results_fn = gen_train(app, app_inputs, trainer)

        # Later graphs reuse the training variables; the trainer is switched to eval mode.
        main_scope.reuse_variables()
        algos_factory.set_eval_mode(trainer)

        if predictor is not None and FLAGS.gen_predict:
            # The returned beam tensors are not used here; only the graph is needed.
            _beam_text, _beam_text_score = gen_predict_graph(predictor)

        eval_ops, eval_feed_fn, eval_results_fn = gen_validate(
            app, app_inputs, trainer, predictor)

    # Generative algos are only metric-evaluated when an assistant model is configured.
    use_metric_eval = FLAGS.metric_eval and (
        not algos_factory.is_generative(FLAGS.algo) or FLAGS.assistant_model_dir)
    metric_eval_fn = (
        (lambda: evaluator.evaluate_scores(predictor, random=True))
        if use_metric_eval else None)

    melt.print_global_varaiables()
    # Only the global scope is restored, since the evaluator may have put another
    # predictor in the graph under a different scope name (e.g. dual_bow).
    # NOTE(review): `global_scope` is not defined in this function; presumably a
    # module-level name -- verify.
    # sess must be passed explicitly if the predictor uses melt.constant.
    melt.apps.train_flow(
        train_ops,
        gen_feed_dict_fn=train_feed_fn,
        deal_results_fn=train_results_fn,
        eval_ops=eval_ops,
        gen_eval_feed_dict_fn=eval_feed_fn,
        deal_eval_results_fn=eval_results_fn,
        optimizer=FLAGS.optimizer,
        learning_rate=FLAGS.learning_rate,
        num_steps_per_epoch=app.num_steps_per_epoch,
        model_dir=FLAGS.model_dir,
        metric_eval_fn=metric_eval_fn,
        restore_scope=global_scope,
        sess=sess)
def train():
    """Create the trainer/predictor for FLAGS.algo, build all graphs and train.

    Graph-building order matters: the training graph creates the variables, and
    the predict/validate graphs built afterwards reuse them.
    """
    app = InputApp.InputApp()
    app_inputs = app.gen_input()

    with tf.variable_scope(FLAGS.main_scope) as main_scope:
        trainer, predictor = algos_factory.gen_trainer_and_predictor(FLAGS.algo)
        logging.info('trainer:{}'.format(trainer))
        logging.info('predictor:{}'.format(predictor))

        # NOTE(review): eval mode is set here, *before* the training graph is built,
        # and again after it. The earlier train() in this module only sets it
        # afterwards -- confirm this first call is intentional.
        algos_factory.set_eval_mode(trainer)
        train_ops, train_feed_fn, train_results_fn = gen_train(app, app_inputs, trainer)

        main_scope.reuse_variables()
        algos_factory.set_eval_mode(trainer)

        if predictor is not None and FLAGS.gen_predict:
            gen_predict_graph(predictor)

        eval_ops, eval_feed_fn, eval_results_fn = gen_validate(
            app, app_inputs, trainer, predictor)

    metric_eval_fn = None
    if FLAGS.metric_eval and not algos_factory.is_generative(FLAGS.algo):
        # Generative algos are skipped: metric eval would work but is slow for them.
        metric_eval_fn = lambda: evaluator.evaluate_scores(predictor, random=True)

    melt.print_global_varaiables()
    # sess must be passed explicitly if the predictor uses melt.constant.
    melt.apps.train_flow(
        train_ops,
        gen_feed_dict_fn=train_feed_fn,
        deal_results_fn=train_results_fn,
        eval_ops=eval_ops,
        gen_eval_feed_dict_fn=eval_feed_fn,
        deal_eval_results_fn=eval_results_fn,
        optimizer=FLAGS.optimizer,
        learning_rate=FLAGS.learning_rate,
        num_steps_per_epoch=app.num_steps_per_epoch,
        model_dir=FLAGS.model_dir,
        metric_eval_fn=metric_eval_fn,
        sess=sess)
def _run_predict_demo(predictor, beam_text, beam_text_score):
    """Load the trained predictor and print beam-search decodes for sample queries.

    This is a manual smoke test, not an evaluation: the queries are hard-coded.
    It runs on the module-level `sess`.

    Args:
      predictor: predictor whose graph was built by gen_predict_graph. It is
        loaded from FLAGS.model_dir, and its input_text_feed receives word ids.
      beam_text: beam-search text tensor returned by gen_predict_graph.
      beam_text_score: matching beam-search score tensor.
    """
    predictor.load(FLAGS.model_dir)
    from conf import INPUT_TEXT_MAX_WORDS

    print('-------------------------', tf.get_collection('scores'))

    # TODO: copied from prepare/gen-records.py; share a single implementation.
    def _text2ids(text, max_words):
        # Segment and map to ids without padding, then truncate and zero-pad
        # to exactly max_words.
        word_ids = text2ids.text2ids(text,
                                     seg_method=FLAGS.seg_method,
                                     feed_single=FLAGS.feed_single,
                                     allow_all_zero=True,
                                     pad=False)
        return gezi.pad(word_ids[:max_words], max_words, 0)

    # Single-query decoding.
    input_texts = [
        '大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施',
    ]
    for input_text in input_texts:
        word_ids = _text2ids(input_text, INPUT_TEXT_MAX_WORDS)
        print('word_ids', word_ids, 'len:', len(word_ids))
        print(text2ids.ids2text(word_ids))

        # Like inference.py: fetch through graph collections and a hard-coded input
        # tensor name. TODO FIXME: this only works for the no-attention model.
        texts, scores = sess.run(
            [tf.get_collection('text')[0], tf.get_collection('text_score')[0]],
            feed_dict={'seq2seq/model_init_1/input_text:0': [word_ids]})
        print(texts[0], text2ids.ids2text(texts[0]), scores[0])

        texts, scores = sess.run([beam_text, beam_text_score],
                                 feed_dict={predictor.input_text_feed: [word_ids]})
        for text, score in zip(texts[0], scores[0]):
            print(text, text2ids.ids2text(text), score)

    # Batched decoding, timed.
    input_texts = [
        '大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施',
        "宝宝太胖怎么办呢",
        '大棚辣椒果实变小怎么办,大棚辣椒果实变小防治措施',
    ]
    word_ids_list = [_text2ids(input_text, INPUT_TEXT_MAX_WORDS) for input_text in input_texts]
    timer = gezi.Timer()
    texts_list, scores_list = sess.run([beam_text, beam_text_score],
                                       feed_dict={predictor.input_text_feed: word_ids_list})
    # Read the timer before printing so the reported time covers only beam search.
    elapsed_ms = timer.elapsed_ms()
    for texts, scores in zip(texts_list, scores_list):
        for text, score in zip(texts, scores):
            print(text, text2ids.ids2text(text), score, math.log(score))
    print('beam_search using time(ms):', elapsed_ms)


def train_process(trainer, predictor=None):
    """Build the train / predict / validate graphs, then train or run a predict demo.

    With FLAGS.mode == 'train', the model is trained via melt.apps.train_flow.
    In any other mode, the predictor is loaded from FLAGS.model_dir and a few
    sample queries are decoded with beam search (see _run_predict_demo).

    Args:
      trainer: model object handed to gen_train and gen_validate.
      predictor: optional model object. Its predict graph is built when
        FLAGS.gen_predict is set, and that graph is required in non-train mode.

    Raises:
      ValueError: in non-train mode when no predict graph was built
        (predictor is None or FLAGS.gen_predict is false).
    """
    input_app = InputApp.InputApp()
    input_results = input_app.gen_input()

    has_predict_graph = predictor is not None and FLAGS.gen_predict
    beam_text = beam_text_score = None
    with tf.variable_scope(FLAGS.main_scope) as scope:
        ops, gen_feed_dict, deal_results = gen_train(input_app, input_results, trainer)
        # Graphs built after this point share the training variables.
        scope.reuse_variables()

        if has_predict_graph:
            beam_text, beam_text_score = gen_predict_graph(predictor)

        eval_ops, gen_eval_feed_dict, deal_eval_results = gen_validate(
            input_app, input_results, trainer, predictor)

    metric_eval_fn = None
    if FLAGS.metric_eval:
        # Generative algos could be evaluated too, but it is slow, so skip them.
        if not algos_factory.is_generative(FLAGS.algo):
            metric_eval_fn = lambda: evaluator.evaluate_scores(predictor, random=True)

    if FLAGS.mode == 'train':
        melt.print_global_varaiables()
        # sess must be passed explicitly if the predictor uses melt.constant.
        melt.apps.train_flow(ops,
                             gen_feed_dict_fn=gen_feed_dict,
                             deal_results_fn=deal_results,
                             eval_ops=eval_ops,
                             gen_eval_feed_dict_fn=gen_eval_feed_dict,
                             deal_eval_results_fn=deal_eval_results,
                             optimizer=FLAGS.optimizer,
                             learning_rate=FLAGS.learning_rate,
                             num_steps_per_epoch=input_app.num_steps_per_epoch,
                             model_dir=FLAGS.model_dir,
                             metric_eval_fn=metric_eval_fn,
                             sess=sess)
    else:
        # Fail fast: the demo needs the beam-search tensors. Previously this
        # crashed later with UnboundLocalError/AttributeError.
        if not has_predict_graph:
            raise ValueError('predict mode requires a predictor and FLAGS.gen_predict '
                             'so that the beam search graph is built')
        _run_predict_demo(predictor, beam_text, beam_text_score)