Beispiel #1
0
def main(argv):
    """Run the caption generator in the phase named by ``FLAGS.phase``.

    ``'train'`` and ``'eval'`` select the training and evaluation phases;
    every other value falls through to the testing phase.

    Args:
        argv: Not referenced in this function.
    """
    settings = Config()
    # Mirror the relevant command-line flags onto the configuration object.
    for option in ('phase', 'train_cnn', 'beam_size'):
        setattr(settings, option, getattr(FLAGS, option))

    with tf.Session() as session:
        phase = FLAGS.phase

        if phase == 'train':
            train_data = prepare_train_data(settings)
            generator = CaptionGenerator(settings)
            session.run(tf.global_variables_initializer())
            # Optional loads happen after initialization and before the
            # graph is finalized.
            if FLAGS.load:
                generator.load(session, FLAGS.model_file)
            if FLAGS.load_cnn:
                generator.load_cnn(session, FLAGS.cnn_model_file)
            tf.get_default_graph().finalize()
            generator.train(session, train_data)
            return

        if phase == 'eval':
            coco, eval_data, vocab = prepare_eval_data(settings)
            generator = CaptionGenerator(settings)
            generator.load(session, FLAGS.model_file)
            tf.get_default_graph().finalize()
            generator.eval(session, coco, eval_data, vocab)
            return

        # Any other phase value: testing.
        test_data, vocab = prepare_test_data(settings)
        generator = CaptionGenerator(settings)
        generator.load(session, FLAGS.model_file)
        tf.get_default_graph().finalize()
        generator.test(session, test_data, vocab)
Beispiel #2
0
def main(argv):
    """Run the train, eval or test phase inside a monitored session.

    Copies the phase, train_cnn and beam_size flags onto a fresh ``Config``,
    obtains a device and a session target from ``distributed_env``, and then
    dispatches on ``FLAGS.phase`` inside a
    ``tf.train.MonitoredTrainingSession``.

    Args:
        argv: Not referenced in this function.
    """
    flags = tf.app.flags
    FLAGS = flags.FLAGS
    config = Config()
    config.phase = FLAGS.phase
    config.train_cnn = FLAGS.train_cnn
    config.beam_size = FLAGS.beam_size

    # Cluster One setting
    clusterone_dist_env = distributed_env(config.root_path_to_local_data,
                                          config.path_to_local_logs,
                                          config.cloud_path_to_data,
                                          config.local_repo,
                                          config.cloud_user_repo, flags)

    clusterone_dist_env.get_env()

    tf.reset_default_graph()
    device, target = clusterone_dist_env.device_and_target(
    )  # getting node environment
    # end of setting

    # Using tensorflow's MonitoredTrainingSession to take care of checkpoints
    # NOTE(review): per the TensorFlow documentation a
    # MonitoredTrainingSession finalizes the graph when it is created, so
    # building the model and the initializer op inside this block may raise
    # "Graph is finalized and cannot be modified" -- confirm, and if so build
    # the graph before opening the session.
    with tf.train.MonitoredTrainingSession(
            master=target,
            is_chief=(FLAGS.task_index == 0),
            checkpoint_dir=FLAGS.log_dir) as sess:

        #     with tf.Session() as sess:
        if FLAGS.phase == 'train':
            # training phase
            data = prepare_train_data(config)
            with tf.device(device):  # define model
                model = CaptionGenerator(config)
            sess.run(tf.global_variables_initializer())
            # Optional loads of a saved model and/or CNN file.
            if FLAGS.load:
                model.load(sess, FLAGS.model_file)
            if FLAGS.load_cnn:
                model.load_cnn(sess, FLAGS.cnn_model_file)
            tf.get_default_graph().finalize()
            model.train(sess, data)

        elif FLAGS.phase == 'eval':
            # evaluation phase
            # The batch size is forced to 1 before the eval data is prepared.
            config.batch_size = 1
            coco, data, vocabulary = prepare_eval_data(config)
            # Unlike the train branch, load and finalize also run inside the
            # device scope here.
            with tf.device(device):  # define model
                model = CaptionGenerator(config)
                model.load(sess, FLAGS.model_file)
                tf.get_default_graph().finalize()
            model.eval(sess, coco, data, vocabulary)

        else:
            # testing phase
            data, vocabulary = prepare_test_data(config)
            with tf.device(device):  # define model
                model = CaptionGenerator(config)
                model.load(sess, FLAGS.model_file)
                tf.get_default_graph().finalize()
            model.test(sess, data, vocabulary)
def evaluate(config):
    """Orchestrates the evaluation process.

    Steps performed:
    - Preparing the evaluation dataset, vocabulary and COCO helper
    - Building the model
    - Loading the results stored at ``config.eval_result_file`` and scoring
      them with ``COCOEvalCap``

    Caption generation itself is currently disabled here, so the built
    model is not used and the result file is read as-is.

    Arguments:
        config (util.Config): Values for various configuration options
    """
    eval_dataset, vocabulary, coco_eval = prepare_eval_data(config)
    model = build_model(config, vocabulary)

    # Score the captions found in the result file and report the time taken.
    started_at = time.time()
    loaded_results = coco_eval.loadRes(config.eval_result_file)
    COCOEvalCap(coco_eval, loaded_results).evaluate()
    logging.info("Evaluation complete.")
    elapsed = time.time() - started_at
    logging.info('Total evaluation time: %d seconds', elapsed)
Beispiel #4
0
def main(argv):
    """Seed the RNGs, build a ``SeqGAN`` model and run the requested phase.

    ``FLAGS.phase`` selects ``'train'``, ``'eval'`` or ``'test_loaded_cnn'``;
    any other value runs the testing phase.

    Args:
        argv: Not referenced in this function.
    """
    cfg = Config()
    for option in ('phase', 'train_cnn', 'beam_size'):
        setattr(cfg, option, getattr(FLAGS, option))
    cfg.trainable_variable = FLAGS.train_cnn

    # Seed both NumPy and TensorFlow from the configured seed.
    np.random.seed(cfg.seed)
    tf.random.set_random_seed(cfg.seed)

    # The SeqGAN model is built before the session is opened.
    gan = SeqGAN(cfg)

    with tf.Session() as session:
        print("FLAG: " + FLAGS.phase)
        phase = FLAGS.phase

        if phase == 'train':
            train_data = prepare_train_data(cfg)
            session.run(tf.global_variables_initializer())
            if FLAGS.load:
                gan.load(session, FLAGS.model_file)
            gan.train(session, train_data)
            return

        if phase == 'eval':
            print("EVALUATING")
            eval_data = prepare_eval_data(cfg)
            gan.load(session, FLAGS.model_file)
            gan.eval(session, eval_data)
            return

        if phase == 'test_loaded_cnn':
            # Original author's note: test phase doesn't work yet.
            # Runs only the CNN on a single image and prints the five
            # highest-scoring classes.
            session.run(tf.global_variables_initializer())
            image_input = tf.placeholder(tf.float32, [None, 224, 224, 3])
            class_probs = gan.test_cnn(image_input)
            gan.load_cnn(session, FLAGS.cnn_model_file)

            picture = imresize(imread(FLAGS.image_file, mode='RGB'),
                               (224, 224))

            prob = session.run(class_probs,
                               feed_dict={image_input: [picture]})[0]
            top_five = np.argsort(prob)[::-1][0:5]
            for idx in top_five:
                print(class_names[idx], prob[idx])
            return

        # Any other phase value: testing.
        test_data = prepare_test_data(cfg)
        gan.load(session, FLAGS.model_file)
        gan.test(session, test_data)
Beispiel #5
0
def main(argv):
    """Run the caption generator on ``FLAGS.namedir``-tagged directories.

    Before each phase the relevant image/result paths in the config are
    rewritten to carry a ``_<namedir>`` tag. ``'train'`` and ``'eval'``
    select training and evaluation; any other phase value runs testing.

    All TensorFlow calls go through ``tf.compat.v1`` so that the block uses
    one API consistently (the session was already created that way).

    Args:
        argv: Not referenced in this function.
    """
    tf1 = tf.compat.v1

    config = Config()
    config.phase = FLAGS.phase
    config.train_cnn = FLAGS.train_cnn
    config.beam_size = FLAGS.beam_size

    tag = "_" + FLAGS.namedir

    def tag_dir(path):
        # "some/dir/" -> "some/dir_<namedir>/". Drops the last character and
        # re-appends "/", so the path is assumed to end with a separator.
        return path[:-1] + tag + "/"

    def tag_file(path, ext_len):
        # Inserts the tag in front of the last `ext_len` characters, i.e.
        # in front of the extension ("x.json" -> "x_<namedir>.json").
        return path[:-ext_len] + tag + path[-ext_len:]

    with tf1.Session() as sess:
        if FLAGS.phase == 'train':
            # training phase
            config.train_image_dir = tag_dir(config.train_image_dir)

            data = prepare_train_data(config)
            model = CaptionGenerator(config)
            sess.run(tf1.global_variables_initializer())
            if FLAGS.load:
                model.load(sess, FLAGS.model_file)
            if FLAGS.load_cnn:
                model.load_cnn(sess, FLAGS.cnn_model_file)
            tf1.get_default_graph().finalize()
            model.train(sess, data)

        elif FLAGS.phase == 'eval':
            # evaluation phase
            config.eval_image_dir = tag_dir(config.eval_image_dir)
            config.eval_result_dir = tag_dir(config.eval_result_dir)
            config.eval_result_file = tag_file(config.eval_result_file,
                                               5)  # .json

            coco, data, vocabulary = prepare_eval_data(config)
            model = CaptionGenerator(config)
            model.load(sess, FLAGS.model_file)
            tf1.get_default_graph().finalize()
            model.eval(sess, coco, data, vocabulary)

        else:
            # testing phase
            config.test_image_dir = tag_dir(config.test_image_dir)
            config.test_result_dir = tag_dir(config.test_result_dir)
            config.test_result_file = tag_file(config.test_result_file,
                                               4)  # .csv

            data, vocabulary = prepare_test_data(config)
            model = CaptionGenerator(config)
            model.load(sess, FLAGS.model_file)
            tf1.get_default_graph().finalize()
            model.test(sess, data, vocabulary)
Beispiel #6
0
def main(argv):
    """Run the caption generator in the phase named by ``FLAGS.phase``.

    Supported values are ``'train'``, ``'eval'`` and ``'test_loaded_cnn'``;
    anything else runs the testing phase.

    Args:
        argv: Not referenced in this function.
    """
    cfg = Config()
    cfg.phase = FLAGS.phase
    cfg.train_cnn = FLAGS.train_cnn
    cfg.beam_size = FLAGS.beam_size
    cfg.trainable_variable = FLAGS.train_cnn

    with tf.Session() as session:
        phase = FLAGS.phase

        if phase == 'train':
            batches = prepare_train_data(cfg)
            captioner = CaptionGenerator(cfg)
            session.run(tf.global_variables_initializer())
            if FLAGS.load:
                captioner.load(session, FLAGS.model_file)
            # Optionally load the CNN file as well.
            if FLAGS.load_cnn:
                captioner.load_cnn(session, FLAGS.cnn_model_file)
            tf.get_default_graph().finalize()
            captioner.train(session, batches)

        elif phase == 'eval':
            coco, batches, vocab = prepare_eval_data(cfg)
            captioner = CaptionGenerator(cfg)
            captioner.load(session, FLAGS.model_file)
            tf.get_default_graph().finalize()
            captioner.eval(session, coco, batches, vocab)

        elif phase == 'test_loaded_cnn':
            # Runs only the CNN on one image and prints the five
            # highest-scoring classes.
            captioner = CaptionGenerator(cfg)
            session.run(tf.global_variables_initializer())
            image_input = tf.placeholder(tf.float32, [None, 224, 224, 3])
            class_probs = captioner.test_cnn(image_input)
            captioner.load_cnn(session, FLAGS.cnn_model_file)

            picture = imread(FLAGS.image_file, mode='RGB')
            picture = imresize(picture, (224, 224))

            fetched = session.run(class_probs,
                                  feed_dict={image_input: [picture]})
            prob = fetched[0]
            ranked = np.argsort(prob)[::-1]
            for idx in ranked[0:5]:
                print(class_names[idx], prob[idx])

        else:
            # Any other phase value: testing.
            batches, vocab = prepare_test_data(cfg)
            captioner = CaptionGenerator(cfg)
            captioner.load(session, FLAGS.model_file)
            tf.get_default_graph().finalize()
            captioner.test(session, batches, vocab)
Beispiel #7
0
def main(argv):
    """Fetch Flickr8k, run the requested phase, then remove the dataset.

    Runs ``tinysrc/download_flickr8k.py`` first, dispatches on
    ``FLAGS.phase`` (``'train'``, ``'eval'``, otherwise testing) and finally
    deletes the Flickr8k directories under ``/output``. The deletion sits in
    a ``finally`` clause so it also happens when a phase raises.

    Args:
        argv: Not referenced in this function.
    """
    os.system("ls /tinysrc")
    os.system("python tinysrc/download_flickr8k.py")

    try:
        config = Config()
        config.phase = FLAGS.phase
        config.train_cnn = FLAGS.train_cnn
        config.joint_train = FLAGS.joint_train
        config.beam_size = FLAGS.beam_size
        config.attention_mechanism = FLAGS.attention
        config.faster_rcnn_frozen = FLAGS.faster_rcnn_frozen

        with tf.Session() as sess:
            if FLAGS.phase == 'train':
                # training phase
                data = prepare_train_data(config)
                model = CaptionGenerator(config)
                sess.run(tf.global_variables_initializer())
                if FLAGS.load:
                    model.load(sess, FLAGS.model_file)
                # The load_cnn flag loads the Faster R-CNN feature
                # extractor checkpoint in this variant.
                if FLAGS.load_cnn:
                    model.load_faster_rcnn_feature_extractor(
                        sess, FLAGS.faster_rcnn_ckpt)
                tf.get_default_graph().finalize()
                model.train(sess, data)

            elif FLAGS.phase == 'eval':
                # evaluation phase
                coco, data, vocabulary = prepare_eval_data(config)
                model = CaptionGenerator(config)
                model.load(sess, FLAGS.model_file)
                tf.get_default_graph().finalize()
                model.eval(sess, coco, data, vocabulary)

            else:
                # testing phase
                data, vocabulary = prepare_test_data(config)
                model = CaptionGenerator(config)
                model.load(sess, FLAGS.model_file)
                tf.get_default_graph().finalize()
                model.test(sess, data, vocabulary)
    finally:
        # Always remove the dataset directories, even after a failure.
        os.system("rm -rf /output/Flickr8k_Dataset/")
        os.system("rm -rf /output/Flickr8k_text/")
Beispiel #8
0
def main(argv):
    """Run the requested phase inside a ``MonitoredTrainingSession``.

    The model and the global-step ops are created before the session is
    opened; the session is given the checkpoint directory and step interval
    from the config. ``FLAGS.phase`` selects ``'train'``, ``'eval'`` or
    (any other value) testing.

    Args:
        argv: Not referenced in this function.
    """
    cfg = Config()
    for option in ('phase', 'train_cnn', 'beam_size'):
        setattr(cfg, option, getattr(FLAGS, option))

    ckpt_dir = cfg.checkpoint_dir
    # Read but not passed to the session below.
    ckpt_secs = cfg.save_checkpoint_secs
    ckpt_steps = cfg.save_checkpoint_steps

    step_var = tf.train.get_or_create_global_step()
    # The increment op is created here; it is not referenced again in this
    # function.
    step_increment = tf.assign_add(step_var, 1)

    captioner = CaptionGenerator(cfg)

    session_manager = tf.train.MonitoredTrainingSession(
        checkpoint_dir=ckpt_dir,
        save_checkpoint_steps=ckpt_steps,
    )
    with session_manager as session:
        phase = FLAGS.phase

        if phase == 'train':
            # WIP (from the original author): explicit model / CNN loading
            # is disabled in this variant.
            batches = prepare_train_data(cfg)
            captioner.train(session, batches)
            return

        if phase == 'eval':
            coco, batches, vocab = prepare_eval_data(cfg)
            tf.get_default_graph().finalize()
            captioner.eval(session, coco, batches, vocab)
            return

        # Any other phase value: testing.
        batches, vocab = prepare_test_data(cfg)
        tf.get_default_graph().finalize()
        captioner.test(session, batches, vocab)
Beispiel #9
0
def main(argv):
    """Run the caption generator with a capped GPU memory fraction.

    The session is created with ``per_process_gpu_memory_fraction=0.9`` and
    device-placement logging off. ``FLAGS.phase`` selects ``'train'``,
    ``'eval'`` or (any other value) testing.

    Args:
        argv: Not referenced in this function.
    """
    cfg = Config()
    cfg.phase = FLAGS.phase
    cfg.train_cnn = FLAGS.train_cnn
    cfg.beam_size = FLAGS.beam_size

    memory_cap = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
    session_config = tf.ConfigProto(log_device_placement=False,
                                    gpu_options=memory_cap)

    with tf.Session(config=session_config) as session:

        def load_frozen_model():
            # Shared by eval and test: build, load the model file, then
            # finalize the graph.
            net = CaptionGenerator(cfg)
            net.load(session, FLAGS.model_file)
            tf.get_default_graph().finalize()
            return net

        if FLAGS.phase == 'train':
            batches = prepare_train_data(cfg)
            net = CaptionGenerator(cfg)
            session.run(tf.global_variables_initializer())
            if FLAGS.load:
                net.load(session, FLAGS.model_file)
            if FLAGS.load_cnn:
                net.load_cnn(session, FLAGS.cnn_model_file)
            tf.get_default_graph().finalize()
            net.train(session, batches)

        elif FLAGS.phase == 'eval':
            coco, batches, vocab = prepare_eval_data(cfg)
            load_frozen_model().eval(session, coco, batches, vocab)

        else:
            # Any other phase value: testing.
            batches, vocab = prepare_test_data(cfg)
            load_frozen_model().test(session, batches, vocab)
Beispiel #10
0
def main(_):
    """Run the caption generator in the mode named by ``FLAGS.mode``.

    ``'train'`` and ``'eval'`` select training and evaluation; every other
    value runs the testing mode. The session is created with GPU memory
    growth enabled.
    """
    cfg = Config()
    cfg.mode = FLAGS.mode
    cfg.train_cnn = FLAGS.train_cnn
    cfg.beam_size = FLAGS.beam_size

    # Allocate GPU memory on demand.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as session:
        run_mode = FLAGS.mode

        if run_mode not in ('train', 'eval'):
            # testing mode
            samples, vocab = prepare_test_data(cfg)
            captioner = CaptionGenerator(cfg)
            captioner.load(session, FLAGS.model_file)
            tf.get_default_graph().finalize()
            captioner.test(session, samples, vocab)
            return

        if run_mode == 'eval':
            # evaluation mode
            coco, samples, vocab = prepare_eval_data(cfg)
            captioner = CaptionGenerator(cfg)
            captioner.load(session, FLAGS.model_file)
            tf.get_default_graph().finalize()
            captioner.eval(session, coco, samples, vocab)
            return

        # training mode
        samples = prepare_train_data(cfg)
        captioner = CaptionGenerator(cfg)
        session.run(tf.global_variables_initializer())
        if FLAGS.load:
            captioner.load(session, FLAGS.model_file)
        if FLAGS.load_cnn:
            captioner.load_cnn(session, FLAGS.cnn_model_file)
        tf.get_default_graph().finalize()
        captioner.train(session, samples)
Beispiel #11
0
def main(argv):
    """Run the captioning model, optionally as a distributed training job.

    ``FLAGS.phase`` selects ``'train'``, ``'eval'`` or (any other value)
    testing. When training with ``FLAGS.distributed`` set, a cluster is built
    from ``FLAGS.ps_hosts`` / ``FLAGS.worker_hosts``; a ``ps`` job blocks in
    ``server.join()`` while a ``worker`` job trains inside a
    ``MonitoredTrainingSession``. The total wall-clock time is printed at
    the end.

    All output uses the call form of ``print`` with a single argument, which
    behaves the same under the Python 2 print statement and Python 3.

    Args:
        argv: Not referenced in this function.
    """
    start_time = time.time()
    config = Config()
    config.phase = FLAGS.phase
    config.train_cnn = FLAGS.train_cnn
    config.beam_size = FLAGS.beam_size
    config.distributed = FLAGS.distributed
    config.test_image_dir = os.path.join(FLAGS.node_root, 'images')
    config.test_result_dir = os.path.join(FLAGS.node_root, 'results')
    # NOTE(review): 'results.cvs' looks like a typo for 'results.csv'; left
    # unchanged because other tooling may rely on this name -- confirm.
    config.test_result_file = os.path.join(FLAGS.node_root, 'results.cvs')
    config.replicas = len(FLAGS.worker_hosts.split(","))
    # An empty task_index flag is treated as task 0.
    if FLAGS.task_index == '':
        config.task_index = 0
    else:
        config.task_index = int(FLAGS.task_index)

    if FLAGS.phase == 'train':
        # training phase

        if FLAGS.distributed:
            config.train_image_dir = FLAGS.input_path

            ps_hosts = FLAGS.ps_hosts.split(",")
            worker_hosts = FLAGS.worker_hosts.split(",")

            # Create a cluster from the parameter server and worker hosts.
            cluster = tf.train.ClusterSpec({
                "ps": ps_hosts,
                "worker": worker_hosts
            })

            # Create and start a server for the local task.
            server = tf.train.Server(cluster,
                                     job_name=FLAGS.job_name,
                                     task_index=config.task_index)

            if FLAGS.job_name == "ps":
                server.join()
            elif FLAGS.job_name == "worker":
                with tf.device(
                        tf.train.replica_device_setter(
                            worker_device="/job:worker/task:%d" %
                            config.task_index,
                            cluster=cluster)):

                    model = CaptionGenerator(config)
                    data = prepare_train_data(config)

                    # No explicit init op is built here: the
                    # MonitoredTrainingSession below takes care of session
                    # initialization.
                    print("Variables Initialized ...")

                begin = time.time()
                # The StopAtStepHook handles stopping after running given
                # steps.
                hooks = [tf.train.StopAtStepHook(num_steps=1200000)]

                # The MonitoredTrainingSession takes care of session
                # initialization, restoring from a checkpoint, saving to a
                # checkpoint, and closing when done or an error occurs.
                with tf.train.MonitoredTrainingSession(
                        master=server.target,
                        is_chief=(config.task_index == 0),
                        checkpoint_dir=
                        "/home/mauro.emc/image_captioning/models",
                        hooks=hooks) as mon_sess:

                    # Another worker may create the directory between an
                    # existence check and the mkdir, so create first and
                    # tolerate "already exists".
                    try:
                        os.makedirs(config.summary_dir)
                    except OSError:
                        if not os.path.isdir(config.summary_dir):
                            raise

                    train_writer = tf.summary.FileWriter(
                        config.summary_dir, mon_sess.graph)
                    try:
                        print("Start the model training")
                        model.train(mon_sess, data, train_writer,
                                    config.task_index)
                    finally:
                        # Close the writer even if training raises.
                        train_writer.close()
                    print("Model stopped train")

                # Reported here, where `begin` is guaranteed to exist.
                print("Train completed")
                print("Total Time in secs: " + str(time.time() - begin))

        else:
            with tf.Session() as sess:
                data = prepare_train_data(config)
                model = CaptionGenerator(config)
                sess.run(tf.global_variables_initializer())
                if FLAGS.load:
                    model.load(sess, FLAGS.model_file)
                if FLAGS.load_cnn:
                    model.load_cnn(sess, FLAGS.cnn_model_file)
                tf.get_default_graph().finalize()
                model.train(sess, data)

    elif FLAGS.phase == 'eval':
        with tf.Session() as sess:
            # evaluation phase
            coco, data, vocabulary = prepare_eval_data(config)
            model = CaptionGenerator(config)
            model.load(sess, FLAGS.model_file)
            tf.get_default_graph().finalize()
            model.eval(sess, coco, data, vocabulary)

    else:
        with tf.Session() as sess:
            # testing phase
            data, vocabulary = prepare_test_data(config)
            model = CaptionGenerator(config)
            model.load(sess, FLAGS.model_file)
            tf.get_default_graph().finalize()
            model.test(sess, data, vocabulary)
    print('Total time in seconds :   ' + str(time.time() - start_time))
Beispiel #12
0
def main(argv):
    """Run the captioning model, optionally distributed via a Supervisor.

    ``FLAGS.phase`` selects ``'train'``, ``'eval'`` or (any other value)
    testing. When training with ``FLAGS.distributed`` set, a cluster is built
    from ``FLAGS.ps_hosts`` / ``FLAGS.worker_hosts``; a ``ps`` job blocks in
    ``server.join()`` while a ``worker`` job trains in a session obtained
    from ``tf.train.Supervisor``. The total wall-clock time is printed at
    the end.

    All output uses the call form of ``print`` with a single argument, which
    behaves the same under the Python 2 print statement and Python 3.

    Args:
        argv: Not referenced in this function.
    """
    start_time = time.time()
    config = Config()
    config.phase = FLAGS.phase
    config.train_cnn = FLAGS.train_cnn
    config.beam_size = FLAGS.beam_size
    config.distributed = FLAGS.distributed
    config.test_image_dir = os.path.join(FLAGS.node_root, 'images')
    config.test_result_dir = os.path.join(FLAGS.node_root, 'results')
    # NOTE(review): 'results.cvs' looks like a typo for 'results.csv'; left
    # unchanged because other tooling may rely on this name -- confirm.
    config.test_result_file = os.path.join(FLAGS.node_root, 'results.cvs')
    config.replicas = len(FLAGS.worker_hosts.split(","))
    config.task_index = FLAGS.task_index

    if FLAGS.phase == 'train':
        # training phase

        if FLAGS.distributed:
            config.train_image_dir = FLAGS.input_path
            print(config.train_image_dir)

            ps_hosts = FLAGS.ps_hosts.split(",")
            worker_hosts = FLAGS.worker_hosts.split(",")

            # Create a cluster from the parameter server and worker hosts.
            cluster = tf.train.ClusterSpec({
                "ps": ps_hosts,
                "worker": worker_hosts
            })

            # Create and start a server for the local task.
            server = tf.train.Server(cluster,
                                     job_name=FLAGS.job_name,
                                     task_index=FLAGS.task_index)

            if FLAGS.job_name == "ps":
                server.join()
            elif FLAGS.job_name == "worker":
                # Reset the graph BEFORE entering the device scope. The
                # scope is registered on the graph that is current when it
                # is entered; resetting inside it would swap in a fresh
                # graph that does not carry the replica device setter.
                tf.reset_default_graph()

                with tf.device(
                        tf.train.replica_device_setter(
                            worker_device="/job:worker/task:%d" %
                            FLAGS.task_index,
                            cluster=cluster)):

                    global_step = tf.get_variable(
                        'global_step', [],
                        initializer=tf.constant_initializer(0),
                        trainable=False,
                        dtype=tf.int32)

                    data = prepare_train_data(config)
                    model = CaptionGenerator(config)

                    init_op = tf.global_variables_initializer()

                is_chief = (FLAGS.task_index == 0)
                # Create a "supervisor", which oversees the training process.
                sv = tf.train.Supervisor(
                    is_chief=is_chief,
                    logdir="/home/mauro.emc/image_captioning/tmp/logs",
                    init_op=init_op,
                    global_step=global_step,
                    save_model_secs=600)
                with sv.prepare_or_wait_for_session(server.target) as sess:
                    if is_chief:
                        # NOTE(review): `chief_queue_runner` and
                        # `init_token_op` are not defined in this function;
                        # unless they are module-level names this raises
                        # NameError on the chief -- confirm where they are
                        # supposed to come from.
                        sv.start_queue_runners(sess, [chief_queue_runner])
                        # Insert initial tokens to the queue.
                        sess.run(init_token_op)
                    # Variables are initialized through the Supervisor's
                    # init_op. No second initializer is built here: it would
                    # add an op to a graph the Supervisor has already
                    # finalized.
                    model.train(sess, data)
                sv.stop()
        else:
            with tf.Session() as sess:
                data = prepare_train_data(config)
                model = CaptionGenerator(config)
                sess.run(tf.global_variables_initializer())
                if FLAGS.load:
                    model.load(sess, FLAGS.model_file)
                if FLAGS.load_cnn:
                    model.load_cnn(sess, FLAGS.cnn_model_file)
                tf.get_default_graph().finalize()
                model.train(sess, data)

    elif FLAGS.phase == 'eval':
        with tf.Session() as sess:
            # evaluation phase
            coco, data, vocabulary = prepare_eval_data(config)
            model = CaptionGenerator(config)
            model.load(sess, FLAGS.model_file)
            tf.get_default_graph().finalize()
            model.eval(sess, coco, data, vocabulary)

    else:
        with tf.Session() as sess:
            # testing phase
            data, vocabulary = prepare_test_data(config)
            model = CaptionGenerator(config)
            model.load(sess, FLAGS.model_file)
            tf.get_default_graph().finalize()
            model.test(sess, data, vocabulary)
    print('Total time in seconds :   ' + str(time.time() - start_time))
Beispiel #13
0
def main(argv):
    """Run the caption generator in the phase named by ``FLAGS.phase``.

    Besides ``'train'`` and ``'eval'``, four ``'test_new_data*'`` phases
    evaluate on additional datasets whose file locations are read from
    suffixed config attributes. Any other value runs the testing phase.

    Args:
        argv: Not referenced in this function.
    """
    config = Config()
    for option in ('phase', 'train_cnn', 'beam_size'):
        setattr(config, option, getattr(FLAGS, option))

    # Phase name -> suffix of the config attributes describing that dataset
    # (eval_caption_file_<s>, eval_image_<s>, eval_result_dir_<s>,
    # eval_result_file_<s>).
    new_data_suffix = {
        'test_new_data': 'unsplash',
        'test_new_data_vizwiz': 'vizwiz_train',
        'test_new_data_insta': 'insta',
        'test_new_data_google_top_n': 'topN',
    }

    with tf.Session() as sess:
        phase = FLAGS.phase

        if phase == 'train':
            train_data = prepare_train_data(config)
            net = CaptionGenerator(config)
            sess.run(tf.global_variables_initializer())
            if FLAGS.load:
                net.load(sess, FLAGS.model_file)
            if FLAGS.load_cnn:
                net.load_cnn(sess, FLAGS.cnn_model_file)
            tf.get_default_graph().finalize()
            net.train(sess, train_data)

        elif phase == 'eval':
            coco, eval_data, vocab = prepare_eval_data(config)
            net = CaptionGenerator(config)
            net.load(sess, FLAGS.model_file)
            tf.get_default_graph().finalize()
            net.eval(sess, coco, eval_data, vocab)

        elif phase in new_data_suffix:
            # Evaluation on one of the extra datasets. Attributes are looked
            # up lazily so only the selected dataset's settings are read.
            suffix = new_data_suffix[phase]
            coco, eval_data, vocab = prepare_eval_new_data(
                getattr(config, 'eval_caption_file_' + suffix),
                getattr(config, 'eval_image_' + suffix),
                config)
            net = CaptionGenerator(config)
            net.load(sess, FLAGS.model_file)
            tf.get_default_graph().finalize()
            net.eval_new_data(
                sess, coco, eval_data, vocab,
                getattr(config, 'eval_result_dir_' + suffix),
                getattr(config, 'eval_result_file_' + suffix))

        else:
            # Any other phase value: testing.
            test_data, vocab = prepare_test_data(config)
            net = CaptionGenerator(config)
            net.load(sess, FLAGS.model_file)
            tf.get_default_graph().finalize()
            net.test(sess, test_data, vocab)
Beispiel #14
0
def main(argv):
    """Seed the RNGs, build the selected model and run the requested phase.

    ``FLAGS.model`` chooses the model class: ``'Show&Tell'`` builds a
    ``CaptionGenerator``, every other value (including ``'SeqGAN'``) builds
    a ``SeqGAN``. ``FLAGS.phase`` selects ``'train'``, ``'eval'``,
    ``'test_loaded_cnn'`` or (any other value) testing.

    Args:
        argv: Not referenced in this function.
    """
    config = Config()

    config.phase = FLAGS.phase
    config.train_cnn = FLAGS.train_cnn
    config.beam_size = FLAGS.beam_size
    config.trainable_variable = FLAGS.train_cnn

    np.random.seed(config.seed)
    tf.random.set_random_seed(config.seed)

    # SeqGAN is both the explicit choice and the fallback.
    model_cls = CaptionGenerator if FLAGS.model == 'Show&Tell' else SeqGAN
    model = model_cls(config)

    with tf.Session() as sess:
        print("FLAG: " + FLAGS.phase)
        if FLAGS.phase == 'train':
            # training phase
            # The training data must be prepared here: `data` is passed to
            # model.train() below and is not defined anywhere else in this
            # function.
            data = prepare_train_data(config)

            sess.run(tf.global_variables_initializer())
            if FLAGS.load:
                model.load(sess, FLAGS.model_file)
            model.train(sess, data)

        elif FLAGS.phase == 'eval':
            print("EVALUATING")
            # evaluation phase
            # NOTE(review): no model.load() and no variable initialization
            # happen before eval in this branch -- confirm that model.eval()
            # restores the weights itself.
            eval_data = prepare_eval_data(config)
            model.eval(sess, eval_data)

        elif FLAGS.phase == 'test_loaded_cnn':
            # testing only cnn: run it on one image and print the five
            # highest-scoring classes
            sess.run(tf.global_variables_initializer())
            imgs = tf.placeholder(tf.float32, [None, 224, 224, 3])
            probs = model.test_cnn(imgs)
            model.load_cnn(sess, FLAGS.cnn_model_file)

            img1 = imread(FLAGS.image_file, mode='RGB')
            img1 = imresize(img1, (224, 224))

            prob = sess.run(probs, feed_dict={imgs: [img1]})[0]
            preds = (np.argsort(prob)[::-1])[0:5]
            for p in preds:
                print(class_names[p], prob[p])

        else:
            # testing phase
            test_data = prepare_test_data(config)
            model.load(sess, FLAGS.model_file)
            model.test(sess, test_data)