Example #1
def train(dataset_train,
          dataset_val,
          dataset_test,
          ckptfile='',
          caffemodel=''):
    print('Training start...')
    batch_size = FLAGS.batch_size

    path = modelpath("")
    if not os.path.exists(path):
        os.makedirs(path)

    with tf.Graph().as_default():

        startstep = 0  #if not is_finetune else int(ckptfile.split('-')[-1])
        global_step = tf.Variable(startstep, trainable=False)

        # placeholders for graph input

        anchor_search = tf.placeholder('float32', shape=(None, 227, 227, 3))
        anchor_street = tf.placeholder('float32', shape=(None, 227, 227, 3))
        anchor_aerial = tf.placeholder('float32', shape=(None, 227, 227, 3))

        positive = tf.placeholder('float32', shape=(None, 227, 227, 3))
        negative = tf.placeholder('float32', shape=(None, 227, 227, 3))

        keep_prob_ = tf.placeholder('float32')

        # graph outputs
        feature_anchor = model.inference_crossview(
            [anchor_search, anchor_street, anchor_aerial], keep_prob_,
            FLAGS.feature, False)
        feature_positive = model.inference(positive, keep_prob_, FLAGS.feature)
        feature_negative = model.inference(negative, keep_prob_, FLAGS.feature)

        feature_size = tf.size(feature_anchor) / batch_size

        feature_list = model.feature_normalize(
            [feature_anchor, feature_positive, feature_negative])

        loss, d_pos, d_neg, loss_origin = model.triplet_loss(
            feature_list[0], feature_list[1], feature_list[2])

        # summary
        summary_op = tf.merge_all_summaries()

        training_loss = tf.placeholder('float32',
                                       shape=(),
                                       name='training_loss')
        training_summary = tf.scalar_summary('training_loss', training_loss)

        optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.lr).minimize(
            loss)  #batch size 512
        #optimizer = tf.train.AdamOptimizer(learning_rate = 0.0000001).minimize(loss)

        #validation
        validation_loss = tf.placeholder('float32',
                                         shape=(),
                                         name='validation_loss')
        validation_summary = tf.scalar_summary('validation_loss',
                                               validation_loss)

        # test

        feature_pair_list = model.feature_normalize(
            [feature_anchor, feature_positive])

        pair_loss = model.eval_loss(feature_pair_list[0], feature_pair_list[1])
        testing_loss = tf.placeholder('float32', shape=(), name='testing_loss')
        testing_summary = tf.scalar_summary('testing_loss', testing_loss)

        init_op = tf.initialize_all_variables()

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            saver = tf.train.Saver(max_to_keep=50)
            if ckptfile:
                # load checkpoint file
                saver.restore(sess, ckptfile)
                """
                sess.run(init_op)
                
                all_vars = tf.all_variables()
                cv_vars = [k for k in all_vars if k.name.startswith("cv_")]
                share_vars = [k for k in all_vars if not k.name.startswith("cv_")]

                saver_share = tf.train.Saver(share_vars)

                saver_share.restore(sess, ckptfile)
                with tf.variable_scope('fc6', reuse=True):
                    w = tf.get_variable('weights')
                    b = tf.get_variable('biases')

                with tf.variable_scope('cv_fc6', reuse=True):
                    for subkey, data in zip(('weights', 'biases'), (w, b)):
                        print 'loading cv_fc6', subkey
                        var = tf.get_variable(subkey)
                        sess.run(var.assign(data))

                """
                print('restore variables done')
            elif caffemodel:
                # load caffemodel generated with caffe-tensorflow
                sess.run(init_op)
                model.load_alexnet(sess, caffemodel)
                print('loaded pretrained caffemodel: {}'.format(caffemodel))
            else:
                # from scratch
                sess.run(init_op)
                print('init_op done')

            summary_writer = tf.train.SummaryWriter("logs/{}/{}/{}".format(
                FLAGS.train_dir, FLAGS.feature, parameter_name),
                                                    graph=sess.graph)

            epoch = 1
            # plain Python counters; this rebinds the tf.Variable global_step defined above
            global_step = step = print_iter_sum = 0
            min_loss = min_test_loss = sys.maxsize
            loss_sum = []

            # main training loop: one sampled triplet batch per iteration
            while True:

                batch_x, batch_y, batch_z, isnextepoch, start, end = dataset_train.sample_path2img(
                    batch_size, True)

                step += len(batch_y)
                global_step += len(batch_y)
                print_iter_sum += len(batch_y)

                feed_dict = {
                    anchor_search: batch_x['search'],
                    anchor_street: batch_x['streetview_clean'],
                    anchor_aerial: batch_x['aerial_clean'],
                    positive: batch_y,
                    negative: batch_z,
                    keep_prob_: 0.5
                }  # dropout rate

                _, loss_value, pos_value, neg_value, origin_value, anchor_value = sess.run(
                    [
                        optimizer, loss, d_pos, d_neg, loss_origin,
                        feature_list[0]
                    ],
                    feed_dict=feed_dict)
                loss_value = np.mean(loss_value)
                loss_sum.append(loss_value)

                if print_iter_sum / print_iter >= 1:
                    loss_sum = np.mean(loss_sum)
                    print('epo{}, {}/{}, loss: {}'.format(
                        epoch, step, len(dataset_train.data), loss_sum))
                    print_iter_sum -= print_iter
                    loss_sum = []

                # serialize the scalar training-loss summary for TensorBoard
                training_summary_str = sess.run(training_summary,
                                                feed_dict={training_loss: loss_value})

                summary_writer.add_summary(training_summary_str, global_step)
                summary_writer.flush()

                # optionally drop triplets whose loss has already reached zero
                action = 0
                if FLAGS.remove and loss_value == 0:
                    action = dataset_train.remove(start, end)
                    if action == 1:
                        finish_training(saver, sess, epoch)
                        break

                if isnextepoch or action == -1:

                    val_loss_sum = []
                    isnextepoch = False  # set for validation
                    step = 0
                    print_iter_sum = 0

                    # validation
                    while not isnextepoch:

                        val_x, val_y, val_z, isnextepoch, start, end = dataset_val.sample_path2img(
                            batch_size, True)
                        val_feed_dict = {
                            anchor_search: val_x['search'],
                            anchor_street: val_x['streetview_clean'],
                            anchor_aerial: val_x['aerial_clean'],
                            positive: val_y,
                            negative: val_z,
                            keep_prob_: 1.
                        }
                        val_loss = sess.run([loss], feed_dict=val_feed_dict)
                        val_loss_sum.append(np.mean(val_loss))

                    dataset_val.reset_sample()
                    val_loss_sum = np.mean(val_loss_sum)
                    print("Validation loss: {}".format(val_loss_sum))

                    summary_val_loss_sum = sess.run(
                        validation_summary,
                        feed_dict={validation_loss: val_loss_sum})
                    summary_writer.add_summary(summary_val_loss_sum,
                                               global_step)

                    # testing
                    #IPython.embed()
                    num = 50
                    test_feed_dict = {
                        anchor_search: dataset_test[0]['search'][:num],
                        anchor_street:
                        dataset_test[0]['streetview_clean'][:num],
                        anchor_aerial: dataset_test[0]['aerial_clean'][:num],
                        positive: dataset_test[1][:num],
                        negative: dataset_test[0]['search'][:num],  # dummy feed; pair_loss ignores the negative branch
                        keep_prob_: 1.
                    }

                    test_loss = sess.run([pair_loss], feed_dict=test_feed_dict)
                    test_loss = np.mean(test_loss)
                    print("Testing loss: {}".format(test_loss))
                    summary_test_loss = sess.run(
                        testing_summary, feed_dict={testing_loss: test_loss})
                    summary_writer.add_summary(summary_test_loss, global_step)

                    # ready to flush
                    summary_str = sess.run(summary_op, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, global_step)
                    summary_writer.flush()
                    # save by testing
                    if min_test_loss > test_loss:
                        min_test_loss = test_loss
                        """
                        if 'best_test_path' in locals():
                            os.remove(best_test_path)
                        """
                        best_test_path = modelpath("test_{}_{}".format(
                            epoch, test_loss))
                        saver.save(sess, best_test_path)
                        print(best_test_path)

                    # save by validation
                    elif min_loss > val_loss_sum:
                        min_loss = val_loss_sum
                        """
                        if 'best_path' in locals():
                            os.remove(best_path)
                        """
                        best_path = modelpath("val_{}_{}".format(
                            epoch, val_loss_sum))
                        saver.save(sess, best_path)
                        print(best_path)
                    # save by SAVE_INTERVAL
                    elif epoch % SAVE_INTERVAL == 0:
                        path = modelpath(epoch)
                        saver.save(sess, path)
                        print(path)

                    dataset_train.reset_sample()
                    print(epoch)

                    epoch += 1
                    if epoch >= max_epo:
                        finish_training(saver, sess, epoch)
                        break
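
The graph above gets its objective from model.triplet_loss, which is not part of this listing. Below is a minimal sketch of a standard margin-based triplet loss with the same return signature (loss, d_pos, d_neg, loss_origin), written against the same pre-1.0 TensorFlow API; the function name, the 0.2 margin, and the mean reduction are assumptions, not the original implementation.

import tensorflow as tf

def triplet_loss_sketch(anchor, positive, negative, margin=0.2):
    """Hypothetical stand-in for model.triplet_loss; the margin value is an assumption."""
    # squared Euclidean distances between the L2-normalized embeddings
    d_pos = tf.reduce_sum(tf.square(anchor - positive), 1)
    d_neg = tf.reduce_sum(tf.square(anchor - negative), 1)
    # hinge form: anchors should sit closer to positives than to negatives by `margin`
    loss_origin = tf.maximum(0., margin + d_pos - d_neg)
    loss = tf.reduce_mean(loss_origin)
    return loss, d_pos, d_neg, loss_origin

Either a scalar or a per-triplet loss vector works with the training loop above, since np.mean(loss_value) collapses whatever shape comes back before logging.
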
Example #2
FLAGS = flags.FLAGS

layer_list = [FLAGS.feature]

output_root = "triplet_feature"

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

with tf.Graph().as_default(), tf.Session(config=config) as sess:

    img_input = tf.placeholder('float32', shape=(None, 227, 227, 3))

    feature = model.inference(img_input, 1, FLAGS.feature, False)  # keep_prob=1: no dropout at inference

    norm_cross_pred = model.feature_normalize([feature])
    pred = norm_cross_pred[0]

    saver = tf.train.Saver()
    saver.restore(sess, FLAGS.model_dir)

    img_name = FLAGS.file.replace(".jpg", "").replace(".png", "")
    print("Load image: {}".format(img_name))
    img_name = osp.basename(img_name)

    img = cv2.imread(FLAGS.file, cv2.IMREAD_COLOR)
    img = transform_img(img, 227, 227)

    for layer in layer_list:
        output_layer = osp.join(output_root, layer)
        if not osp.exists(output_layer):
            os.makedirs(output_layer)
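        # Sketch only, not part of the original listing: a plausible continuation
        # that runs the normalized feature for the loaded image and stores it in
        # the per-layer directory. The added batch dimension, the assumed
        # `import numpy as np`, and the .npy file name are assumptions.
        feat = sess.run(pred, feed_dict={img_input: img[None, ...]})
        np.save(osp.join(output_layer, img_name + ".npy"), feat)
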
Example #3
def train(dataset_train, dataset_val, ckptfile='', caffemodel=''):
    print('Training start...')
    is_finetune = bool(ckptfile)
    batch_size = FLAGS.batch_size

    path = modelpath("")
    if not os.path.exists(path):
        os.makedirs(path)

    with tf.Graph().as_default():

        startstep = 0  #if not is_finetune else int(ckptfile.split('-')[-1])
        global_step = tf.Variable(startstep, trainable=False)

        # placeholders for graph input

        anchor = tf.placeholder('float32', shape=(None, 227, 227, 3))
        positive = tf.placeholder('float32', shape=(None, 227, 227, 3))
        negative = tf.placeholder('float32', shape=(None, 227, 227, 3))

        keep_prob_ = tf.placeholder('float32')

        # graph outputs
        feature_anchor = model.inference(anchor, keep_prob_, FLAGS.feature,
                                         False)
        feature_positive = model.inference(positive, keep_prob_, FLAGS.feature)
        feature_negative = model.inference(negative, keep_prob_, FLAGS.feature)

        feature_size = tf.size(feature_anchor) / batch_size

        feature_list = model.feature_normalize(
            [feature_anchor, feature_positive, feature_negative])

        loss, d_pos, d_neg, loss_origin = model.triplet_loss(
            feature_list[0], feature_list[1], feature_list[2])

        # summary
        summary_op = tf.merge_all_summaries()

        training_loss = tf.placeholder('float32',
                                       shape=(),
                                       name='training_loss')
        training_summary = tf.scalar_summary('training_loss', training_loss)

        optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.lr).minimize(
            loss)  #batch size 512

        #validation
        validation_loss = tf.placeholder('float32',
                                         shape=(),
                                         name='validation_loss')
        validation_summary = tf.scalar_summary('validation_loss',
                                               validation_loss)

        init_op = tf.initialize_all_variables()

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            saver = tf.train.Saver(max_to_keep=max_epo)
            if ckptfile:
                # load checkpoint file
                saver.restore(sess, ckptfile)
                print('restore variables done')
            elif caffemodel:
                # load caffemodel generated with caffe-tensorflow
                sess.run(init_op)
                model.load_alexnet(sess, caffemodel)
                print('loaded pretrained caffemodel:{}'.format(caffemodel))
            else:
                # from scratch
                sess.run(init_op)
                print('init_op done')

            summary_writer = tf.train.SummaryWriter("logs/{}/{}/{}".format(
                FLAGS.train_dir, FLAGS.feature, parameter_name),
                                                    graph=sess.graph)

            epoch = 1
            # plain Python counters; this rebinds the tf.Variable global_step defined above
            global_step = step = print_iter_sum = 0
            min_loss = min_test_loss = sys.maxsize
            loss_sum = []

            # main training loop: one sampled triplet batch per iteration
            while True:

                batch_x, batch_y, batch_z, isnextepoch, start, end = dataset_train.sample_path2img(
                    batch_size)

                step += len(batch_x)
                global_step += len(batch_x)
                print_iter_sum += len(batch_x)

                feed_dict = {
                    anchor: batch_x,
                    positive: batch_y,
                    negative: batch_z,
                    keep_prob_: FLAGS.dropout
                }  # dropout rate

                _, loss_value, pos_value, neg_value, origin_value, anchor_value = sess.run(
                    [
                        optimizer, loss, d_pos, d_neg, loss_origin,
                        feature_list[0]
                    ],
                    feed_dict=feed_dict)
                loss_value = np.mean(loss_value)
                loss_sum.append(loss_value)

                if print_iter_sum / print_iter >= 1:
                    loss_sum = np.mean(loss_sum)
                    print('epo{}, {}/{}, loss: {}'.format(
                        epoch, step, len(dataset_train.data), loss_sum))
                    print_iter_sum -= print_iter
                    loss_sum = []

                # serialize the scalar training-loss summary for TensorBoard
                training_summary_str = sess.run(training_summary,
                                                feed_dict={training_loss: loss_value})

                summary_writer.add_summary(training_summary_str, global_step)
                summary_writer.flush()

                # optionally drop triplets whose loss has already reached zero
                action = 0
                if FLAGS.remove and loss_value == 0:
                    action = dataset_train.remove(start, end)
                    if action == 1:
                        finish_training(saver, sess, epoch)
                        break

                if isnextepoch or action == -1:

                    val_loss_sum = []
                    isnextepoch = False  # set for validation
                    step = 0
                    print_iter_sum = 0

                    # validation
                    while not isnextepoch:

                        val_x, val_y, val_z, isnextepoch, start, end = dataset_val.sample_path2img(
                            batch_size)
                        val_feed_dict = {
                            anchor: val_x,
                            positive: val_y,
                            negative: val_z,
                            keep_prob_: 1.
                        }
                        val_loss = sess.run([loss], feed_dict=val_feed_dict)
                        val_loss_sum.append(np.mean(val_loss))

                    dataset_val.reset_sample()
                    val_loss_sum = np.mean(val_loss_sum)
                    print("Validation loss: {}".format(val_loss_sum))

                    summary_val_loss_sum = sess.run(
                        validation_summary,
                        feed_dict={validation_loss: val_loss_sum})
                    summary_writer.add_summary(summary_val_loss_sum,
                                               global_step)

                    # ready to flush
                    summary_str = sess.run(summary_op, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, global_step)
                    summary_writer.flush()

                    # save by validation
                    if min_loss > val_loss_sum:
                        min_loss = val_loss_sum
                        best_path = modelpath("val_{}_{}".format(
                            epoch, val_loss_sum))
                        saver.save(sess, best_path)
                        print(best_path)

                    # save by SAVE_INTERVAL
                    elif epoch % SAVE_INTERVAL == 0:
                        path = modelpath(epoch)
                        saver.save(sess, path)
                        print(path)

                    dataset_train.reset_sample()
                    print(epoch)

                    epoch += 1
                    if epoch >= max_epo:
                        finish_training(saver, sess, epoch)
                        break
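
All three examples pass the raw embeddings through model.feature_normalize before distances are computed. That helper is also not shown in the listing; the sketch below illustrates the row-wise L2 normalization it presumably performs, again using the pre-1.0 TensorFlow argument names (keep_dims). The function name and the epsilon guard are assumptions.

import tensorflow as tf

def feature_normalize_sketch(feature_list, epsilon=1e-12):
    """Hypothetical stand-in for model.feature_normalize: row-wise L2 normalization."""
    normalized = []
    for feature in feature_list:
        # divide each embedding by its L2 norm so the triplet distances
        # compare directions rather than magnitudes
        norm = tf.sqrt(tf.reduce_sum(tf.square(feature), 1, keep_dims=True))
        normalized.append(feature / tf.maximum(norm, epsilon))
    return normalized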