# -*- coding: utf-8 -*-

import sys
sys.path.append('..')

import numpy as np
from model import triplet_loss

if __name__ == '__main__':
    anchor = np.random.randn(1, 128)
    positive = np.random.randn(1, 128)
    negative = np.random.randn(1, 128)
    print(triplet_loss(anchor, positive, negative))
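
The triplet_loss imported from model is not reproduced in this listing. As a rough illustration, a standard margin-based formulation over a batch of embedding vectors could look like the sketch below; the margin value and the use of squared Euclidean distance are assumptions, not details taken from the repository.

import numpy as np

def triplet_loss(anchor, positive, negative, margin=0.2):
    # squared Euclidean distance from the anchor to the positive and to the negative
    d_pos = np.sum((anchor - positive) ** 2, axis=1)
    d_neg = np.sum((anchor - negative) ** 2, axis=1)
    # hinge: penalize triplets where the negative is not at least `margin` farther away
    return np.mean(np.maximum(d_pos - d_neg + margin, 0.0))
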
Example 2
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Model, SimulatedDataset, triplet_loss and triplet_acc come from the
# project's own modules; device is assumed to be set up as below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train(opt):
    writer = SummaryWriter()  # tensorboard

    train_loader = DataLoader(SimulatedDataset(opt.train_dir),
                              batch_size=opt.bs,
                              shuffle=True)
    valid_loader = DataLoader(SimulatedDataset(opt.valid_dir),
                              batch_size=opt.bs)

    model = Model().to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

    counter = 0  # epochs since improvement
    best_loss = float("inf")

    print("train_loss\tvalid_loss\tvalid_acc")

    for epoch in range(opt.n_epochs):
        model.train()
        train_loss = 0.0
        size = len(train_loader.dataset)

        # train step
        for i, (anchor_imgs, same_imgs, diff_imgs) in enumerate(train_loader):
            optimizer.zero_grad()
            anchor = model(anchor_imgs.to(device))
            same = model(same_imgs.to(device))
            diff = model(diff_imgs.to(device))

            loss = triplet_loss(anchor, same, diff)
            train_loss += loss.item() * anchor.size(0)
            loss.backward()  # backprop
            optimizer.step()
        train_loss /= size

        # validation step
        model.eval()
        with torch.no_grad():
            valid_loss = 0.0
            total_accuracy = 0.0
            size = len(valid_loader.dataset)
            for i, (anchor_imgs, same_imgs,
                    diff_imgs) in enumerate(valid_loader):
                anchor = model(anchor_imgs.to(device))
                same = model(same_imgs.to(device))
                diff = model(diff_imgs.to(device))

                loss = triplet_loss(anchor, same, diff)
                valid_loss += loss.item() * anchor.size(0)
                total_accuracy += triplet_acc(anchor, same,
                                              diff) * anchor.size(0)
        valid_loss /= size
        total_accuracy /= size

        print(f"{train_loss:.3f}\t{valid_loss:.3f}\t{total_accuracy:.4f}")
        writer.add_scalar("loss/train", train_loss, epoch)
        writer.add_scalar("loss/valid", valid_loss, epoch)
        writer.add_scalar("acc/valid", total_accuracy, epoch)
        writer.flush()

        # early stopping
        if valid_loss < best_loss:
            counter = 0
            print("new best loss, saving checkpoint...")
            torch.save(model.state_dict(), f"models/checkpoint_{epoch}.pth")
            torch.save(model.state_dict(), f"models/weights.pth")
            best_loss = valid_loss
        else:
            counter += 1

        if counter > opt.patience:
            print(f"{opt.patience} epochs without improvement, exiting")
            break
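
This training loop relies on triplet_loss and triplet_acc defined elsewhere in the project. A minimal PyTorch sketch of both helpers, assuming Euclidean distances and a fixed margin (neither is confirmed by this listing), might be:

import torch
import torch.nn.functional as F

def triplet_loss(anchor, positive, negative, margin=0.2):
    # hinge on the gap between anchor-positive and anchor-negative distances
    d_pos = F.pairwise_distance(anchor, positive)
    d_neg = F.pairwise_distance(anchor, negative)
    return torch.clamp(d_pos - d_neg + margin, min=0.0).mean()

def triplet_acc(anchor, positive, negative):
    # fraction of triplets where the positive lands closer to the anchor than the negative
    d_pos = F.pairwise_distance(anchor, positive)
    d_neg = F.pairwise_distance(anchor, negative)
    return (d_pos < d_neg).float().mean().item()
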
Example 3
def train(dataset_train,
          dataset_val,
          dataset_test,
          ckptfile='',
          caffemodel=''):
    print('Training start...')
    batch_size = FLAGS.batch_size

    path = modelpath("")
    if not os.path.exists(path):
        os.makedirs(path)

    with tf.Graph().as_default():

        startstep = 0  #if not is_finetune else int(ckptfile.split('-')[-1])
        global_step = tf.Variable(startstep, trainable=False)

        # placeholders for graph input

        anchor_search = tf.placeholder('float32', shape=(None, 227, 227, 3))
        anchor_street = tf.placeholder('float32', shape=(None, 227, 227, 3))
        anchor_aerial = tf.placeholder('float32', shape=(None, 227, 227, 3))

        positive = tf.placeholder('float32', shape=(None, 227, 227, 3))
        negative = tf.placeholder('float32', shape=(None, 227, 227, 3))

        keep_prob_ = tf.placeholder('float32')

        # graph outputs
        feature_anchor = model.inference_crossview(
            [anchor_search, anchor_street, anchor_aerial], keep_prob_,
            FLAGS.feature, False)
        feature_positive = model.inference(positive, keep_prob_, FLAGS.feature)
        feature_negative = model.inference(negative, keep_prob_, FLAGS.feature)

        feature_size = tf.size(feature_anchor) / batch_size

        feature_list = model.feature_normalize(
            [feature_anchor, feature_positive, feature_negative])

        loss, d_pos, d_neg, loss_origin = model.triplet_loss(
            feature_list[0], feature_list[1], feature_list[2])

        # summary
        summary_op = tf.merge_all_summaries()

        training_loss = tf.placeholder('float32',
                                       shape=(),
                                       name='training_loss')
        training_summary = tf.scalar_summary('training_loss', training_loss)

        optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.lr).minimize(
            loss)  #batch size 512
        #optimizer = tf.train.AdamOptimizer(learning_rate = 0.0000001).minimize(loss)

        #validation
        validation_loss = tf.placeholder('float32',
                                         shape=(),
                                         name='validation_loss')
        validation_summary = tf.scalar_summary('validation_loss',
                                               validation_loss)

        # test

        feature_pair_list = model.feature_normalize(
            [feature_anchor, feature_positive])

        pair_loss = model.eval_loss(feature_pair_list[0], feature_pair_list[1])
        testing_loss = tf.placeholder('float32', shape=(), name='testing_loss')
        testing_summary = tf.scalar_summary('testing_loss', testing_loss)

        init_op = tf.initialize_all_variables()

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            saver = tf.train.Saver(max_to_keep=50)
            if ckptfile:
                # load checkpoint file
                saver.restore(sess, ckptfile)
                """
                sess.run(init_op)
                
                all_vars = tf.all_variables()
                cv_vars = [k for k in all_vars if k.name.startswith("cv_")]
                share_vars = [k for k in all_vars if not k.name.startswith("cv_")]

                saver_share = tf.train.Saver(share_vars)

                saver_share.restore(sess, ckptfile)
                with tf.variable_scope('fc6', reuse=True):
                    w = tf.get_variable('weights')
                    b = tf.get_variable('biases')

                with tf.variable_scope('cv_fc6', reuse=True):
                    for subkey, data in zip(('weights', 'biases'), (w, b)):
                        print 'loading cv_fc6', subkey
                        var = tf.get_variable(subkey)
                        sess.run(var.assign(data))

                """
                print('restore variables done')
            elif caffemodel:
                # load caffemodel generated with caffe-tensorflow
                sess.run(init_op)
                model.load_alexnet(sess, caffemodel)
                print('loaded pretrained caffemodel: {}'.format(caffemodel))
            else:
                # from scratch
                sess.run(init_op)
                print('init_op done')

            summary_writer = tf.train.SummaryWriter("logs/{}/{}/{}".format(
                FLAGS.train_dir, FLAGS.feature, parameter_name),
                                                    graph=sess.graph)

            epoch = 1
            global_step = step = print_iter_sum = 0
            min_loss = min_test_loss = float("inf")  # sys.maxint is Python 2-only
            loss_sum = []

            while True:

                batch_x, batch_y, batch_z, isnextepoch, start, end = dataset_train.sample_path2img(
                    batch_size, True)

                step += len(batch_y)
                global_step += len(batch_y)
                print_iter_sum += len(batch_y)

                feed_dict = {
                    anchor_search: batch_x['search'],
                    anchor_street: batch_x['streetview_clean'],
                    anchor_aerial: batch_x['aerial_clean'],
                    positive: batch_y,
                    negative: batch_z,
                    keep_prob_: 0.5
                }  # dropout keep probability during training

                _, loss_value, pos_value, neg_value, origin_value, anchor_value = sess.run(
                    [
                        optimizer, loss, d_pos, d_neg, loss_origin,
                        feature_list[0]
                    ],
                    feed_dict=feed_dict)
                loss_value = np.mean(loss_value)
                loss_sum.append(loss_value)

                if print_iter_sum / print_iter >= 1:
                    loss_sum = np.mean(loss_sum)
                    print('epo{}, {}/{}, loss: {}'.format(
                        epoch, step, len(dataset_train.data), loss_sum))
                    print_iter_sum -= print_iter
                    loss_sum = []

                train_summary_str = sess.run(training_summary,
                                             feed_dict={training_loss: loss_value})

                summary_writer.add_summary(train_summary_str, global_step)
                summary_writer.flush()

                action = 0
                if FLAGS.remove and loss_value == 0:
                    action = dataset_train.remove(start, end)
                    if action == 1:
                        finish_training(saver, sess, epoch)
                        break

                if isnextepoch or action == -1:

                    val_loss_sum = []
                    isnextepoch = False  # set for validation
                    step = 0
                    print_iter_sum = 0

                    # validation
                    while not isnextepoch:

                        val_x, val_y, val_z, isnextepoch, start, end = dataset_val.sample_path2img(
                            batch_size, True)
                        val_feed_dict = {
                            anchor_search: val_x['search'],
                            anchor_street: val_x['streetview_clean'],
                            anchor_aerial: val_x['aerial_clean'],
                            positive: val_y,
                            negative: val_z,
                            keep_prob_: 1.
                        }
                        val_loss = sess.run([loss], feed_dict=val_feed_dict)
                        val_loss_sum.append(np.mean(val_loss))

                    dataset_val.reset_sample()
                    val_loss_sum = np.mean(val_loss_sum)
                    print("Validation loss: {}".format(val_loss_sum))

                    summary_val_loss_sum = sess.run(
                        validation_summary,
                        feed_dict={validation_loss: val_loss_sum})
                    summary_writer.add_summary(summary_val_loss_sum,
                                               global_step)

                    # testing
                    #IPython.embed()
                    num = 50
                    test_feed_dict = {
                        anchor_search: dataset_test[0]['search'][:num],
                        anchor_street:
                        dataset_test[0]['streetview_clean'][:num],
                        anchor_aerial: dataset_test[0]['aerial_clean'][:num],
                        positive: dataset_test[1][:num],
                        negative: dataset_test[0]['search'][:num],  # not used by pair_loss
                        keep_prob_: 1.
                    }

                    test_loss = sess.run([pair_loss], feed_dict=test_feed_dict)
                    test_loss = np.mean(test_loss)
                    print("Testing loss: {}".format(test_loss))
                    summary_test_loss = sess.run(
                        testing_summary, feed_dict={testing_loss: test_loss})
                    summary_writer.add_summary(summary_test_loss, global_step)

                    # ready to flush
                    summary_str = sess.run(summary_op, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, global_step)
                    summary_writer.flush()
                    # save by testing
                    if min_test_loss > test_loss:
                        min_test_loss = test_loss
                        """
                        if 'best_test_path' in locals():
                            os.remove(best_test_path)
                        """
                        best_test_path = modelpath("test_{}_{}".format(
                            epoch, test_loss))
                        saver.save(sess, best_test_path)
                        print(best_test_path)

                    # save by validation
                    elif min_loss > val_loss_sum:
                        min_loss = val_loss_sum
                        """
                        if 'best_path' in locals():
                            os.remove(best_path)
                        """
                        best_path = modelpath("val_{}_{}".format(
                            epoch, val_loss_sum))
                        saver.save(sess, best_path)
                        print(best_path)
                    # save by SAVE_INTERVAL
                    elif epoch % SAVE_INTERVAL == 0:
                        path = modelpath(epoch)
                        saver.save(sess, path)
                        print(path)

                    dataset_train.reset_sample()
                    print(epoch)

                    epoch += 1
                    if epoch >= max_epo:
                        finish_training(saver, sess, epoch)
                        break
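
Both this example and the next call model.feature_normalize before computing the loss. Its body is not shown; a plausible minimal sketch, assuming it simply L2-normalizes each feature tensor along the feature axis (a common step before a triplet loss), would be:

import tensorflow as tf

def feature_normalize(feature_list):
    # L2-normalize every feature tensor row-wise so distances are taken between unit vectors
    return [tf.nn.l2_normalize(f, 1) for f in feature_list]
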
Example 4
def train(dataset_train, dataset_val, ckptfile='', caffemodel=''):
    print('Training start...')
    is_finetune = bool(ckptfile)
    batch_size = FLAGS.batch_size

    path = modelpath("")
    if not os.path.exists(path):
        os.makedirs(path)

    with tf.Graph().as_default():

        startstep = 0  #if not is_finetune else int(ckptfile.split('-')[-1])
        global_step = tf.Variable(startstep, trainable=False)

        # placeholders for graph input

        anchor = tf.placeholder('float32', shape=(None, 227, 227, 3))
        positive = tf.placeholder('float32', shape=(None, 227, 227, 3))
        negative = tf.placeholder('float32', shape=(None, 227, 227, 3))

        keep_prob_ = tf.placeholder('float32')

        # graph outputs
        feature_anchor = model.inference(anchor, keep_prob_, FLAGS.feature,
                                         False)
        feature_positive = model.inference(positive, keep_prob_, FLAGS.feature)
        feature_negative = model.inference(negative, keep_prob_, FLAGS.feature)

        feature_size = tf.size(feature_anchor) / batch_size

        feature_list = model.feature_normalize(
            [feature_anchor, feature_positive, feature_negative])

        loss, d_pos, d_neg, loss_origin = model.triplet_loss(
            feature_list[0], feature_list[1], feature_list[2])

        # summary
        summary_op = tf.merge_all_summaries()

        training_loss = tf.placeholder('float32',
                                       shape=(),
                                       name='training_loss')
        training_summary = tf.scalar_summary('training_loss', training_loss)

        optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.lr).minimize(
            loss)  #batch size 512

        #validation
        validation_loss = tf.placeholder('float32',
                                         shape=(),
                                         name='validation_loss')
        validation_summary = tf.scalar_summary('validation_loss',
                                               validation_loss)

        init_op = tf.initialize_all_variables()

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:

            saver = tf.train.Saver(max_to_keep=max_epo)
            if ckptfile:
                # load checkpoint file
                saver.restore(sess, ckptfile)
                print('restore variables done')
            elif caffemodel:
                # load caffemodel generated with caffe-tensorflow
                sess.run(init_op)
                model.load_alexnet(sess, caffemodel)
                print('loaded pretrained caffemodel: {}'.format(caffemodel))
            else:
                # from scratch
                sess.run(init_op)
                print('init_op done')

            summary_writer = tf.train.SummaryWriter("logs/{}/{}/{}".format(
                FLAGS.train_dir, FLAGS.feature, parameter_name),
                                                    graph=sess.graph)

            epoch = 1
            global_step = step = print_iter_sum = 0
            min_loss = min_test_loss = float("inf")  # sys.maxint is Python 2-only
            loss_sum = []

            while True:

                batch_x, batch_y, batch_z, isnextepoch, start, end = dataset_train.sample_path2img(
                    batch_size)

                step += len(batch_x)
                global_step += len(batch_x)
                print_iter_sum += len(batch_x)

                feed_dict = {
                    anchor: batch_x,
                    positive: batch_y,
                    negative: batch_z,
                    keep_prob_: FLAGS.dropout
                }  # dropout keep probability

                _, loss_value, pos_value, neg_value, origin_value, anchor_value = sess.run(
                    [
                        optimizer, loss, d_pos, d_neg, loss_origin,
                        feature_list[0]
                    ],
                    feed_dict=feed_dict)
                loss_value = np.mean(loss_value)
                loss_sum.append(loss_value)

                if print_iter_sum / print_iter >= 1:
                    loss_sum = np.mean(loss_sum)
                    print('epo{}, {}/{}, loss: {}'.format(
                        epoch, step, len(dataset_train.data), loss_sum))
                    print_iter_sum -= print_iter
                    loss_sum = []

                train_summary_str = sess.run(training_summary,
                                             feed_dict={training_loss: loss_value})

                summary_writer.add_summary(train_summary_str, global_step)
                summary_writer.flush()

                action = 0
                if FLAGS.remove and loss_value == 0:
                    action = dataset_train.remove(start, end)
                    if action == 1:
                        finish_training(saver, sess, epoch)
                        break

                if isnextepoch or action == -1:

                    val_loss_sum = []
                    isnextepoch = False  # set for validation
                    step = 0
                    print_iter_sum = 0

                    # validation
                    while not isnextepoch:

                        val_x, val_y, val_z, isnextepoch, start, end = dataset_val.sample_path2img(
                            batch_size)
                        val_feed_dict = {
                            anchor: val_x,
                            positive: val_y,
                            negative: val_z,
                            keep_prob_: 1.
                        }
                        val_loss = sess.run([loss], feed_dict=val_feed_dict)
                        val_loss_sum.append(np.mean(val_loss))

                    dataset_val.reset_sample()
                    val_loss_sum = np.mean(val_loss_sum)
                    print("Validation loss: {}".format(val_loss_sum))

                    summary_val_loss_sum = sess.run(
                        validation_summary,
                        feed_dict={validation_loss: val_loss_sum})
                    summary_writer.add_summary(summary_val_loss_sum,
                                               global_step)

                    # ready to flush
                    summary_str = sess.run(summary_op, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, global_step)
                    summary_writer.flush()

                    # save by validation
                    if min_loss > val_loss_sum:
                        min_loss = val_loss_sum
                        best_path = modelpath("val_{}_{}".format(
                            epoch, val_loss_sum))
                        saver.save(sess, best_path)
                        print(best_path)

                    # save by SAVE_INTERVAL
                    elif epoch % SAVE_INTERVAL == 0:
                        path = modelpath(epoch)
                        saver.save(sess, path)
                        print(path)

                    dataset_train.reset_sample()
                    print(epoch)

                    epoch += 1
                    if epoch >= max_epo:
                        finish_training(saver, sess, epoch)
                        break
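
In both TensorFlow examples, model.triplet_loss returns four values: the reduced loss used by the optimizer plus d_pos, d_neg and a per-triplet term loss_origin. A TF 1.x-style sketch consistent with that signature is given below; the margin, the squared Euclidean distance and the exact meaning of loss_origin are assumptions rather than details taken from the repository.

import tensorflow as tf

def triplet_loss(anchor, positive, negative, margin=0.2):
    # per-triplet squared Euclidean distances between (already normalized) features
    d_pos = tf.reduce_sum(tf.square(anchor - positive), 1)
    d_neg = tf.reduce_sum(tf.square(anchor - negative), 1)
    loss_origin = tf.maximum(d_pos - d_neg + margin, 0.0)  # per-triplet hinge loss
    loss = tf.reduce_mean(loss_origin)  # scalar fed to the AdamOptimizer
    return loss, d_pos, d_neg, loss_origin
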