def tower_loss(trainer, input_app=None, input_results=None):
    """Build the combined caption + scene loss for one training tower.

    The caption graph is built under device '/gpu:0' and the scene graph
    under '/gpu:1'. The combined loss is ``main_loss + 0.1 * scene_loss``.

    Side effect: the module-level ``gtrain_image_name`` is set to the image
    names of the caption training inputs (used for rl).

    Args:
        trainer: object providing ``build_train_graph(image_feature, text)``
            and ``build_scene_graph(image, label)``.
        input_app: input application; a fresh ``InputApp.InputApp()`` is
            created when None.
        input_results: mapping of input name to input tuple; generated with
            ``input_app.gen_input(train_only=True)`` when None.

    Returns:
        Tuple ``(loss, main_loss, scene_loss)``.
    """
    global gtrain_image_name  # for rl

    if input_app is None:
        input_app = InputApp.InputApp()
    if input_results is None:
        input_results = input_app.gen_input(train_only=True)

    # --------train (caption inputs)
    # With use_weights the input tuple carries one extra trailing element.
    # The weights are unpacked but not consumed anywhere in this function.
    if FLAGS.use_weights:
        image_name, image_feature, text, text_str, weights = input_results[
            input_app.input_train_name]
    else:
        weights = None
        image_name, image_feature, text, text_str = input_results[
            input_app.input_train_name]

    gtrain_image_name = image_name

    with tf.device('/gpu:0'):
        main_loss = trainer.build_train_graph(image_feature, text)

    # --------train (scene inputs)
    if FLAGS.use_weights:
        image_name, image, label, text_str, weights = input_results[
            input_app.scene_input_train_name]
    else:
        weights = None
        image_name, image, label, text_str = input_results[
            input_app.scene_input_train_name]

    with tf.device('/gpu:1'):
        scene_loss = trainer.build_scene_graph(image, label)

    return main_loss + 0.1 * scene_loss, main_loss, scene_loss
def train_process(trainer, predictor=None):
    """Build the train / predict / validate graphs and run the train flow.

    Everything graph-related is created inside the ``FLAGS.main_scope``
    variable scope. After the training graph is built the scope is switched
    to reuse mode, so the predict and validate graphs are built with
    variable reuse enabled.

    Args:
        trainer: trainer object passed to ``gen_train`` and ``gen_validate``.
        predictor: optional predictor; when given and ``FLAGS.gen_predict``
            is set, its predict graph is generated as well.
    """
    input_app = InputApp.InputApp()
    input_results = input_app.gen_input()

    with tf.variable_scope(FLAGS.main_scope) as scope:
        train_ops, train_feed_fn, train_results_fn = gen_train(
            input_app, input_results, trainer)
        scope.reuse_variables()

        # Save the predict graph so later prediction does not have to build
        # it from scratch. It is also used by gen_validate when predicting
        # directly as the per-epoch evaluation.
        if predictor is not None and FLAGS.gen_predict:
            gen_predict_graph(predictor)

        validate_ops, validate_feed_fn, validate_results_fn = gen_validate(
            input_app, input_results, trainer, predictor)

        # Generative algos could be metric-evaluated too, but it is slow, so
        # it is skipped for them.
        if FLAGS.metric_eval and not algos_factory.is_generative(FLAGS.algo):
            metric_eval_fn = lambda: evaluator.evaluate_scores(
                predictor, random=True)
        else:
            metric_eval_fn = None

    init_fn, summary_excls = None, None
    if not FLAGS.pre_calc_image_feature:
        init_fn = melt.image.create_image_model_init_fn(
            FLAGS.image_model_name, FLAGS.image_checkpoint_file)
        if predictor is not None and FLAGS.gen_predict:
            # Image model summaries are excluded: with gen_predict its ops
            # might need image_feature_feed.
            summary_excls = [FLAGS.image_model_name]

    melt.print_global_varaiables()

    # `sess` is not defined in this function. Notice: if melt.constant is
    # used in the predictor then sess must be passed.
    melt.apps.train_flow(
        train_ops,
        gen_feed_dict_fn=train_feed_fn,
        deal_results_fn=train_results_fn,
        eval_ops=validate_ops,
        gen_eval_feed_dict_fn=validate_feed_fn,
        deal_eval_results_fn=validate_results_fn,
        optimizer=FLAGS.optimizer,
        learning_rate=FLAGS.learning_rate,
        num_steps_per_epoch=input_app.num_steps_per_epoch,
        model_dir=FLAGS.model_dir,
        metric_eval_fn=metric_eval_fn,
        summary_excls=summary_excls,
        init_fn=init_fn,
        sess=sess)
def tower_loss(trainer, input_app=None, input_results=None):
    """Build the training loss for one tower, with optional negative text.

    Args:
        trainer: object providing
            ``build_train_graph(image_feature, text, neg_text)``.
        input_app: input application; a fresh ``InputApp.InputApp()`` is
            created when None.
        input_results: mapping of input name to input tuple; generated with
            ``input_app.gen_input(train_only=True)`` when None.

    Returns:
        Whatever ``trainer.build_train_graph`` returns.
    """
    if input_app is None:
        input_app = InputApp.InputApp()
    if input_results is None:
        input_results = input_app.gen_input(train_only=True)

    # --------train: only the image feature and the text are consumed
    _, image_feature, text, _ = input_results[input_app.input_train_name]

    # --------train neg: negatives are optional, absent means None
    neg_text = None
    if input_app.input_train_neg_name in input_results:
        neg_text, _ = input_results[input_app.input_train_neg_name]

    return trainer.build_train_graph(image_feature, text, neg_text)
def tower_loss(trainer, input_app=None, input_results=None):
    """Build the training loss for one tower, with optional negative samples.

    Negative image features are dropped unless ``FLAGS.neg_left`` is set,
    and negative text is dropped unless ``FLAGS.neg_right`` is set.

    Args:
        trainer: object providing ``build_train_graph(image_feature, text,
            neg_image_feature, neg_text)``.
        input_app: input application; a fresh ``InputApp.InputApp()`` is
            created when None.
        input_results: mapping of input name to input tuple; generated with
            ``input_app.gen_input(train_only=True)`` when None.

    Returns:
        Whatever ``trainer.build_train_graph`` returns.
    """
    if input_app is None:
        input_app = InputApp.InputApp()
    if input_results is None:
        input_results = input_app.gen_input(train_only=True)

    # --------train
    _, image_feature, text, _ = input_results[input_app.input_train_name]

    # --------train neg
    # Default both negatives to None so that missing or empty negative
    # inputs never leave these names unbound (previously this raised
    # UnboundLocalError when a neg flag was set, or KeyError when the
    # negative input was not generated at all).
    neg_image_feature = None
    neg_text = None
    try:
        neg_inputs = input_results[input_app.input_train_neg_name]
    except KeyError:
        neg_inputs = None
    if neg_inputs:
        _, neg_image_feature, neg_text, _ = neg_inputs

    if not FLAGS.neg_left:
        neg_image_feature = None
    if not FLAGS.neg_right:
        neg_text = None

    return trainer.build_train_graph(image_feature, text, neg_image_feature,
                                     neg_text)
def train_process(trainer, predictor=None):
    """Build the train / predict / validate graphs and run the train flow.

    All graphs are created inside the 'run' variable scope. After the
    training graph is built the scope is switched to reuse mode, so the
    predict and validate graphs are built with variable reuse enabled.

    Args:
        trainer: trainer object passed to ``gen_train`` and ``gen_validate``.
        predictor: optional predictor; when given, its predict graph is
            generated as well.
    """
    input_app = InputApp.InputApp()
    input_results = input_app.gen_input()

    with tf.variable_scope('run') as scope:
        train_ops, train_feed_fn, train_results_fn = gen_train(
            input_app, input_results, trainer)
        scope.reuse_variables()

        # Save the predict graph so later prediction does not have to build
        # it from scratch. It is also used by gen_validate when predicting
        # directly as the per-epoch evaluation.
        if predictor is not None:
            gen_predict_graph(predictor)

        validate_ops, validate_feed_fn, validate_results_fn = gen_validate(
            input_app, input_results, trainer, predictor)

        # Generative algos could be metric-evaluated too, but it is slow, so
        # it is skipped for them.
        if FLAGS.metric_eval and not algos_factory.is_generative(FLAGS.algo):
            metric_eval = lambda: evaluator.evaluate_scores(
                predictor, random=True)
        else:
            metric_eval = None

    # `sess` is not defined in this function. Notice: if melt.constant is
    # used in the predictor then sess must be passed.
    melt.apps.train_flow(
        train_ops,
        gen_feed_dict=train_feed_fn,
        deal_results=train_results_fn,
        eval_ops=validate_ops,
        gen_eval_feed_dict=validate_feed_fn,
        deal_eval_results=validate_results_fn,
        optimizer=FLAGS.optimizer,
        learning_rate=FLAGS.learning_rate,
        num_steps_per_epoch=input_app.num_steps_per_epoch,
        model_dir=FLAGS.model_dir,
        metric_eval_function=metric_eval,
        sess=sess)
def train():
    """Create trainer and predictor for ``FLAGS.algo`` and run the train flow.

    All graphs are created inside the ``FLAGS.main_scope`` variable scope.
    After the training graph is built the scope is switched to reuse mode
    and the trainer is put into eval mode before the predict and validate
    graphs are generated.
    """
    input_app = InputApp.InputApp()
    input_results = input_app.gen_input()

    with tf.variable_scope(FLAGS.main_scope) as scope:
        trainer, predictor = algos_factory.gen_trainer_and_predictor(
            FLAGS.algo)
        logging.info('trainer:{}'.format(trainer))
        logging.info('predictor:{}'.format(predictor))

        train_ops, train_feed_fn, train_results_fn = gen_train(
            input_app, input_results, trainer)

        scope.reuse_variables()
        algos_factory.set_eval_mode(trainer)

        # Save the predict graph so later prediction does not have to build
        # it from scratch. It is also used by gen_validate when predicting
        # directly as the per-epoch evaluation.
        if predictor is not None and FLAGS.gen_predict:
            gen_predict_graph(predictor)

        validate_ops, validate_feed_fn, validate_results_fn = gen_validate(
            input_app, input_results, trainer, predictor)

        metric_eval_fn = None
        if FLAGS.metric_eval:
            # Rank eval runs for non-generative algos, or when an assistant
            # model dir is configured; translation eval only for generative
            # algos.
            eval_rank = FLAGS.eval_rank and (
                not algos_factory.is_generative(FLAGS.algo)
                or FLAGS.assistant_model_dir)
            eval_translation = (FLAGS.eval_translation
                                and algos_factory.is_generative(FLAGS.algo))
            metric_eval_fn = lambda: evaluator.evaluate(
                predictor,
                random=True,
                eval_rank=eval_rank,
                eval_translation=eval_translation)

    init_fn, restore_fn, summary_excls = None, None, None
    if not FLAGS.pre_calc_image_feature:
        init_fn = melt.image.image_processing.create_image_model_init_fn(
            FLAGS.image_model_name, FLAGS.image_checkpoint_file)
        # When a checkpoint already exists in model_dir but does not hold
        # the image model variables, the init fn is also used as restore fn.
        if (melt.checkpoint_exists_in(FLAGS.model_dir)
                and not melt.varname_in_checkpoint(FLAGS.image_model_name,
                                                   FLAGS.model_dir)):
            restore_fn = init_fn

    # `sess` is not defined in this function. Notice: if melt.constant is
    # used in the predictor then sess must be passed.
    melt.apps.train_flow(
        train_ops,
        gen_feed_dict_fn=train_feed_fn,
        deal_results_fn=train_results_fn,
        eval_ops=validate_ops,
        gen_eval_feed_dict_fn=validate_feed_fn,
        deal_eval_results_fn=validate_results_fn,
        optimizer=FLAGS.optimizer,
        learning_rate=FLAGS.learning_rate,
        num_steps_per_epoch=input_app.num_steps_per_epoch,
        model_dir=FLAGS.model_dir,
        metric_eval_fn=metric_eval_fn,
        summary_excls=summary_excls,
        init_fn=init_fn,
        restore_fn=restore_fn,
        sess=sess)
def train():
    """Create trainer / validator / predictor for ``FLAGS.algo`` and train.

    Graphs are created inside ``FLAGS.main_scope`` nested in the global
    scope returned by ``melt.apps.train.get_global_scope()``. The image
    model init / restore functions are obtained outside of every variable
    scope, and the train flow is then run with the global scope re-entered.
    """
    input_app = InputApp.InputApp()
    input_results = input_app.gen_input()

    root_scope_name = melt.apps.train.get_global_scope()
    with tf.variable_scope(root_scope_name) as root_scope:
        with tf.variable_scope(FLAGS.main_scope) as scope:
            trainer, validator, predictor = algos_factory.gen_all(FLAGS.algo)

            if FLAGS.reinforcement_learning:
                from deepiu.image_caption.algos.show_and_tell import RLInfo
                trainer.rl = RLInfo()

            if FLAGS.scene_model:
                # TODO scene model does not support show_eval right now due
                # to batch size, eg 3 for show_eval
                # NOTE(review): validator and predictor are dereferenced
                # here without a None check, while predictor is guarded
                # against None further below. Confirm gen_all's contract.
                from deepiu.image_caption.algos.show_and_tell import SceneInfo
                trainer.scene = SceneInfo(FLAGS.batch_size)
                validator.scene = SceneInfo(FLAGS.eval_batch_size)
                predictor.scene = SceneInfo()

            logging.info('trainer:{}'.format(trainer))
            logging.info('predictor:{}'.format(predictor))

            train_ops, train_feed_fn, train_results_fn = gen_train(
                input_app, input_results, trainer)

            scope.reuse_variables()

            # Save the predict graph so later prediction does not have to
            # build it from scratch. It is also used by gen_validate when
            # predicting directly as the per-epoch evaluation.
            if predictor is not None and FLAGS.gen_predict:
                gen_predict_graph(predictor)

            validate_ops, validate_feed_fn, validate_results_fn = gen_validate(
                input_app, input_results, validator, predictor)

            metric_eval_fn = None
            if FLAGS.metric_eval:
                if FLAGS.scene_model:
                    evaluator.gen_feed_dict_fn = (
                        lambda image_features: gen_predict_scene_feed_dict(
                            predictor, image_features))

                # Rank eval runs for non-generative algos, or when an
                # assistant model dir is configured; translation eval only
                # for generative algos.
                eval_rank = FLAGS.eval_rank and (
                    not algos_factory.is_generative(FLAGS.algo)
                    or FLAGS.assistant_model_dir)
                eval_translation = (
                    FLAGS.eval_translation
                    and algos_factory.is_generative(FLAGS.algo))
                metric_eval_fn = lambda: evaluator.evaluate(
                    predictor,
                    random=True,
                    eval_rank=eval_rank,
                    eval_translation=eval_translation)

    # NOTICE: in empty scope now, the image model needs to escape all scopes!
    summary_excls = None
    init_fn, restore_fn = image_util.get_init_restore_fn()

    with tf.variable_scope(root_scope):
        # `names`, `eval_names` and `sess` are not defined in this function.
        # Notice: if melt.constant is used in the predictor then sess must
        # be passed.
        melt.apps.train_flow(
            train_ops,
            names=names,
            gen_feed_dict_fn=train_feed_fn,
            deal_results_fn=train_results_fn,
            eval_ops=validate_ops,
            eval_names=eval_names,
            gen_eval_feed_dict_fn=validate_feed_fn,
            deal_eval_results_fn=validate_results_fn,
            optimizer=FLAGS.optimizer,
            learning_rate=FLAGS.learning_rate,
            num_steps_per_epoch=input_app.num_steps_per_epoch,
            metric_eval_fn=metric_eval_fn,
            summary_excls=summary_excls,
            init_fn=init_fn,
            restore_fn=restore_fn,
            sess=sess)