def tower_loss(trainer, input_app=None, input_results=None,
               scene_loss_weight=0.1):
    """Build the combined main (caption) + scene training loss.

    Args:
      trainer: object providing ``build_train_graph(image_feature, text)`` and
        ``build_scene_graph(image, label)``.
      input_app: ``InputApp.InputApp`` instance; a new one is created if None.
      input_results: mapping from input names to input tuples, as produced by
        ``input_app.gen_input``; generated with ``train_only=True`` if None.
      scene_loss_weight: multiplier applied to the scene loss before it is
        added to the main loss (default 0.1, the previously hard-coded value).

    Returns:
      Tuple ``(loss, main_loss, scene_loss)`` with
      ``loss = main_loss + scene_loss_weight * scene_loss``.
    """
    # The train image names are exposed as a module global for the RL code.
    global gtrain_image_name  #for rl

    if input_app is None:
        input_app = InputApp.InputApp()
    if input_results is None:
        input_results = input_app.gen_input(train_only=True)

    #--------train
    # With FLAGS.use_weights each input tuple carries one extra trailing item
    # (per-example weights); this loss unpacks it but does not use it.
    train_inputs = input_results[input_app.input_train_name]
    if FLAGS.use_weights:
        image_name, image_feature, text, _text_str, _weights = train_inputs
    else:
        image_name, image_feature, text, _text_str = train_inputs

    gtrain_image_name = image_name

    with tf.device('/gpu:0'):
        main_loss = trainer.build_train_graph(image_feature, text)

    #--------scene
    scene_inputs = input_results[input_app.scene_input_train_name]
    if FLAGS.use_weights:
        _scene_name, image, label, _scene_text_str, _scene_weights = scene_inputs
    else:
        _scene_name, image, label, _scene_text_str = scene_inputs

    with tf.device('/gpu:1'):
        scene_loss = trainer.build_scene_graph(image, label)

    loss = main_loss + scene_loss_weight * scene_loss
    return loss, main_loss, scene_loss
# Example #2
def train_process(trainer, predictor=None):
    """Build the train/eval graphs for ``trainer`` and run melt's train loop.

    Args:
      trainer: algorithm trainer used for both the train and validate graphs.
      predictor: optional predictor. When given and FLAGS.gen_predict is set,
        its predict graph is also built, and (unless image features are
        pre-calculated) the image model's summaries are excluded.

    NOTE(review): ``sess`` is not defined in this function; presumably a
    module-level session -- confirm.
    """
    app = InputApp.InputApp()
    results = app.gen_input()

    with tf.variable_scope(FLAGS.main_scope) as main_scope:
        ops, gen_feed_dict, deal_results = gen_train(app, results, trainer)

        # Graphs built from here on reuse the training variables.
        main_scope.reuse_variables()

        # Building the predict graph now means it is saved with the model, so
        # later runs can predict directly without rebuilding it; gen_validate
        # may also use it for per-epoch evaluation.
        build_predict = predictor is not None and FLAGS.gen_predict
        if build_predict:
            gen_predict_graph(predictor)

        eval_ops, gen_eval_feed_dict, deal_eval_results = gen_validate(
            app, results, trainer, predictor)

        # Metric evaluation is skipped for generative algorithms: it would
        # work but is slow.
        metric_eval_fn = None
        if FLAGS.metric_eval and not algos_factory.is_generative(FLAGS.algo):
            def metric_eval_fn():
                return evaluator.evaluate_scores(predictor, random=True)

    init_fn = summary_excls = None
    if not FLAGS.pre_calc_image_feature:
        init_fn = melt.image.create_image_model_init_fn(
            FLAGS.image_model_name, FLAGS.image_checkpoint_file)
        if build_predict:
            # Exclude the image model's summaries: with a predict graph
            # present, its ops might need image_feature_feed to be fed.
            summary_excls = [FLAGS.image_model_name]

    melt.print_global_varaiables()
    melt.apps.train_flow(
        ops,
        gen_feed_dict_fn=gen_feed_dict,
        deal_results_fn=deal_results,
        eval_ops=eval_ops,
        gen_eval_feed_dict_fn=gen_eval_feed_dict,
        deal_eval_results_fn=deal_eval_results,
        optimizer=FLAGS.optimizer,
        learning_rate=FLAGS.learning_rate,
        num_steps_per_epoch=app.num_steps_per_epoch,
        model_dir=FLAGS.model_dir,
        metric_eval_fn=metric_eval_fn,
        summary_excls=summary_excls,
        init_fn=init_fn,
        # If the predictor uses melt.constant, the session must be passed in.
        sess=sess)
# Example #3
def tower_loss(trainer, input_app=None, input_results=None):
    """Build the training loss from positive and optional negative text.

    Args:
      trainer: object providing
        ``build_train_graph(image_feature, text, neg_text)``.
      input_app: ``InputApp.InputApp`` instance; a new one is created if None.
      input_results: mapping from input names to input tuples; generated
        with ``train_only=True`` if None.

    Returns:
      Whatever ``trainer.build_train_graph`` returns. ``neg_text`` is None
      when ``input_results`` has no entry for the negative-train input.
    """
    input_app = InputApp.InputApp() if input_app is None else input_app
    if input_results is None:
        input_results = input_app.gen_input(train_only=True)

    #--------train
    _image_name, image_feature, text, _text_str = input_results[
        input_app.input_train_name]

    #--------train neg (optional)
    neg_text = None
    neg_key = input_app.input_train_neg_name
    if neg_key in input_results:
        neg_text, _neg_text_str = input_results[neg_key]

    return trainer.build_train_graph(image_feature, text, neg_text)
# Example #4
def tower_loss(trainer, input_app=None, input_results=None):
    """Build the training loss, optionally using negative image/text samples.

    Args:
      trainer: object providing ``build_train_graph(image_feature, text,
        neg_image_feature, neg_text)``.
      input_app: ``InputApp.InputApp`` instance; a new one is created if None.
      input_results: mapping from input names to input tuples; generated
        with ``train_only=True`` if None.

    Returns:
      Whatever ``trainer.build_train_graph`` returns. The negative features
      are None when the negative-train input is missing or empty. They are
      also None when disabled by ``FLAGS.neg_left`` (image side) or
      ``FLAGS.neg_right`` (text side).
    """
    if input_app is None:
        input_app = InputApp.InputApp()
    if input_results is None:
        input_results = input_app.gen_input(train_only=True)

    #--------train
    _image_name, image_feature, text, _text_str = input_results[
        input_app.input_train_name]

    #--------train neg
    # Default to "no negatives". Previously a missing or empty negative entry
    # left these names unbound (KeyError / UnboundLocalError below).
    neg_image_feature, neg_text = None, None
    neg_name = input_app.input_train_neg_name
    if neg_name in input_results and input_results[neg_name]:
        (_neg_image_name, neg_image_feature, neg_text,
         _neg_text_str) = input_results[neg_name]

    # Each negative side can be switched off independently.
    if not FLAGS.neg_left:
        neg_image_feature = None
    if not FLAGS.neg_right:
        neg_text = None
    return trainer.build_train_graph(image_feature, text, neg_image_feature,
                                     neg_text)
# Example #5
def train_process(trainer, predictor=None):
    """Build train/eval graphs under the 'run' scope and start training.

    Args:
      trainer: algorithm trainer used for both the train and validate graphs.
      predictor: optional predictor; when given, its predict graph is built
        so it is saved together with the model.

    NOTE(review): ``sess`` is not defined in this function; presumably a
    module-level session -- confirm.
    """
    app = InputApp.InputApp()
    results = app.gen_input()

    with tf.variable_scope('run') as run_scope:
        ops, gen_feed_dict, deal_results = gen_train(app, results, trainer)

        # Graphs built from here on reuse the training variables.
        run_scope.reuse_variables()

        # Saving the predict graph lets later runs predict directly without
        # rebuilding it; gen_validate may also use it for per-epoch evaluation.
        if predictor is not None:
            gen_predict_graph(predictor)

        eval_ops, gen_eval_feed_dict, deal_eval_results = gen_validate(
            app, results, trainer, predictor)

        # Metric evaluation is skipped for generative algorithms (slow).
        metric_eval_function = None
        if FLAGS.metric_eval and not algos_factory.is_generative(FLAGS.algo):
            def metric_eval_function():
                return evaluator.evaluate_scores(predictor, random=True)

    melt.apps.train_flow(
        ops,
        gen_feed_dict=gen_feed_dict,
        deal_results=deal_results,
        eval_ops=eval_ops,
        gen_eval_feed_dict=gen_eval_feed_dict,
        deal_eval_results=deal_eval_results,
        optimizer=FLAGS.optimizer,
        learning_rate=FLAGS.learning_rate,
        num_steps_per_epoch=app.num_steps_per_epoch,
        model_dir=FLAGS.model_dir,
        metric_eval_function=metric_eval_function,
        # If the predictor uses melt.constant, the session must be passed in.
        sess=sess)
# Example #6
def train():
    """Create the trainer/predictor for FLAGS.algo, build graphs and train.

    The train, predict and validate graphs are all built under
    FLAGS.main_scope, with variables shared after the train graph. The
    image-model init/restore functions are set up unless image features are
    pre-calculated.

    NOTE(review): ``sess`` is not defined in this function; presumably a
    module-level session -- confirm.
    """
    app = InputApp.InputApp()
    results = app.gen_input()

    with tf.variable_scope(FLAGS.main_scope) as main_scope:
        trainer, predictor = algos_factory.gen_trainer_and_predictor(
            FLAGS.algo)
        logging.info('trainer:{}'.format(trainer))
        logging.info('predictor:{}'.format(predictor))

        ops, gen_feed_dict, deal_results = gen_train(app, results, trainer)

        # Share variables with everything built below; the trainer is then
        # switched to eval mode before the predict/validate graphs are built.
        main_scope.reuse_variables()
        algos_factory.set_eval_mode(trainer)

        # Saving the predict graph lets later runs predict directly without
        # rebuilding it; gen_validate may also use it for per-epoch evaluation.
        if predictor is not None and FLAGS.gen_predict:
            gen_predict_graph(predictor)

        eval_ops, gen_eval_feed_dict, deal_eval_results = gen_validate(
            app, results, trainer, predictor)

        metric_eval_fn = None
        if FLAGS.metric_eval:
            # Ranking: non-generative algos, or any algo with an assistant
            # model dir. Translation metrics: generative algos only.
            eval_rank = FLAGS.eval_rank and (
                not algos_factory.is_generative(FLAGS.algo)
                or FLAGS.assistant_model_dir)
            eval_translation = (FLAGS.eval_translation
                                and algos_factory.is_generative(FLAGS.algo))

            def metric_eval_fn():
                return evaluator.evaluate(predictor,
                                          random=True,
                                          eval_rank=eval_rank,
                                          eval_translation=eval_translation)

    init_fn = restore_fn = summary_excls = None

    if not FLAGS.pre_calc_image_feature:
        init_fn = melt.image.image_processing.create_image_model_init_fn(
            FLAGS.image_model_name, FLAGS.image_checkpoint_file)
        # Reuse the image-model init_fn as restore_fn when model_dir already
        # has a checkpoint that contains no variables named after the image
        # model.
        if melt.checkpoint_exists_in(FLAGS.model_dir) and \
                not melt.varname_in_checkpoint(FLAGS.image_model_name,
                                               FLAGS.model_dir):
            restore_fn = init_fn

    melt.apps.train_flow(
        ops,
        gen_feed_dict_fn=gen_feed_dict,
        deal_results_fn=deal_results,
        eval_ops=eval_ops,
        gen_eval_feed_dict_fn=gen_eval_feed_dict,
        deal_eval_results_fn=deal_eval_results,
        optimizer=FLAGS.optimizer,
        learning_rate=FLAGS.learning_rate,
        num_steps_per_epoch=app.num_steps_per_epoch,
        model_dir=FLAGS.model_dir,
        metric_eval_fn=metric_eval_fn,
        summary_excls=summary_excls,
        init_fn=init_fn,
        restore_fn=restore_fn,
        # If the predictor uses melt.constant, the session must be passed in.
        sess=sess)
def train():
    """Build trainer/validator/predictor graphs for FLAGS.algo and train.

    Model graphs are built under ``<global scope>/<FLAGS.main_scope>``. The
    image-model init/restore functions are created outside those scopes.
    ``melt.apps.train_flow`` is then run back inside the global scope.

    NOTE(review): ``names``, ``eval_names`` and ``sess`` are not assigned in
    this function, so they must be module-level globals or the train_flow
    call raises NameError. ``gen_train`` / ``gen_validate`` are unpacked into
    3 values here. If they are meant to also return ``names`` /
    ``eval_names``, this unpacking needs updating -- confirm against their
    definitions.
    """
    input_app = InputApp.InputApp()
    input_results = input_app.gen_input()
    global_scope = melt.apps.train.get_global_scope()

    # NOTE: `as global_scope` rebinds the name from the value returned by
    # get_global_scope() to the entered variable-scope object; the later
    # `tf.variable_scope(global_scope)` is passed that object.
    with tf.variable_scope(global_scope) as global_scope:
        with tf.variable_scope(FLAGS.main_scope) as scope:
            trainer, validator, predictor = algos_factory.gen_all(FLAGS.algo)

            # Optional reinforcement-learning state attached to the trainer.
            if FLAGS.reinforcement_learning:
                from deepiu.image_caption.algos.show_and_tell import RLInfo
                trainer.rl = RLInfo()

            if FLAGS.scene_model:
                # TODO: the scene model does not support show_eval yet, because
                # show_eval uses a different batch size (e.g. 3).
                # Trainer and validator get a SceneInfo sized by their batch
                # size; the predictor uses SceneInfo's default.
                # NOTE(review): predictor is None-checked further down but not
                # here; confirm gen_all always returns a predictor.
                from deepiu.image_caption.algos.show_and_tell import SceneInfo
                trainer.scene = SceneInfo(FLAGS.batch_size)
                validator.scene = SceneInfo(FLAGS.eval_batch_size)
                predictor.scene = SceneInfo()

            logging.info('trainer:{}'.format(trainer))
            logging.info('predictor:{}'.format(predictor))

            ops, gen_feed_dict, deal_results = gen_train(
                input_app, input_results, trainer)

            # Predict/validate graphs built below reuse the training variables.
            scope.reuse_variables()

            # Saving the predict graph lets later runs predict directly without
            # rebuilding it from scratch; gen_validate may also use it for
            # per-epoch evaluation.
            if predictor is not None and FLAGS.gen_predict:
                gen_predict_graph(predictor)

            eval_ops, gen_eval_feed_dict, deal_eval_results = gen_validate(
                input_app, input_results, validator, predictor)

            metric_eval_fn = None
            if FLAGS.metric_eval:
                # With a scene model, install a feed-dict builder on the
                # evaluator that goes through gen_predict_scene_feed_dict.
                if FLAGS.scene_model:
                    evaluator.gen_feed_dict_fn = lambda image_features: gen_predict_scene_feed_dict(
                        predictor, image_features)

                # Ranking: non-generative algos, or any algo with an assistant
                # model dir. Translation metrics: generative algos only.
                eval_rank = FLAGS.eval_rank and (
                    not algos_factory.is_generative(FLAGS.algo)
                    or FLAGS.assistant_model_dir)
                eval_translation = FLAGS.eval_translation and algos_factory.is_generative(
                    FLAGS.algo)
                metric_eval_fn = lambda: evaluator.evaluate(
                    predictor,
                    random=True,
                    eval_rank=eval_rank,
                    eval_translation=eval_translation)

    # NOTICE: we are back in the empty (root) scope here; the image model must
    # be set up outside all of the scopes above.
    summary_excls = None
    init_fn, restore_fn = image_util.get_init_restore_fn()

    # Re-enter the global scope object for the training loop.
    with tf.variable_scope(global_scope):
        #melt.print_global_varaiables()
        melt.apps.train_flow(
            ops,
            names=names,
            gen_feed_dict_fn=gen_feed_dict,
            deal_results_fn=deal_results,
            eval_ops=eval_ops,
            eval_names=eval_names,
            gen_eval_feed_dict_fn=gen_eval_feed_dict,
            deal_eval_results_fn=deal_eval_results,
            optimizer=FLAGS.optimizer,
            learning_rate=FLAGS.learning_rate,
            num_steps_per_epoch=input_app.num_steps_per_epoch,
            metric_eval_fn=metric_eval_fn,
            summary_excls=summary_excls,
            init_fn=init_fn,
            restore_fn=restore_fn,
            sess=sess
        )  # notice if use melt.constant in predictor then must pass sess