Example #1
0
def train(start_epoch=1):
    '''
    Train the CVM-NET network and run validation after every epoch.

    :param start_epoch: the epoch id to start training from. The first
        epoch is 1; the checkpoint saved for ``start_epoch - 1`` is
        restored before training resumes.
    :raises ValueError: if the global ``network_type`` is not one of the
        supported CVM-NET variants.
    '''

    # import data
    input_data = InputData()

    # define placeholders for the satellite / ground image pair
    sat_x = tf.placeholder(tf.float32, [None, 512, 512, 3], name='sat_x')
    grd_x = tf.placeholder(tf.float32, [None, 224, 1232, 3], name='grd_x')
    keep_prob = tf.placeholder(tf.float32)  # dropout keep probability
    learning_rate = tf.placeholder(tf.float32)

    # build model
    if network_type == 'CVM-NET-I':
        sat_global, grd_global = cvm_net_I(sat_x, grd_x, keep_prob,
                                           is_training)
    elif network_type == 'CVM-NET-II':
        sat_global, grd_global = cvm_net_II(sat_x, grd_x, keep_prob,
                                            is_training)
    else:
        # BUG FIX: previously this only printed and fell through to a
        # NameError on `sat_global`; fail fast with a clear exception.
        raise ValueError(
            'CONFIG ERROR: wrong network type, only CVM-NET-I and CVM-NET-II are valid'
        )

    # define loss (margin parameter fixed at 0)
    loss = compute_loss(sat_global, grd_global, 0)

    # set training: Adam with an externally fed learning rate
    global_step = tf.Variable(0, trainable=False)
    with tf.device('/gpu:0'):
        with tf.name_scope('train'):
            train_step = tf.train.AdamOptimizer(
                learning_rate, 0.9, 0.999).minimize(loss,
                                                    global_step=global_step)

    saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)

    # run model
    print('run model...')
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        print('load model...')
        load_model_path = '../Model/' + network_type + '/' + str(
            start_epoch - 1) + '/model.ckpt'
        saver.restore(sess, load_model_path)
        print("   Model loaded from: %s" % load_model_path)
        print('load model...FINISHED')

        # Train
        for epoch in range(start_epoch, start_epoch + number_of_epoch):
            batch_idx = 0  # renamed from `iter`, which shadowed the builtin
            if is_training:
                while True:
                    batch_sat, batch_grd = input_data.next_pair_batch(
                        batch_size)
                    if batch_sat is None:
                        break  # epoch exhausted

                    feed_dict = {
                        sat_x: batch_sat,
                        grd_x: batch_grd,
                        learning_rate: learning_rate_val,
                        keep_prob: keep_prob_val
                    }
                    if batch_idx % 20 == 0:
                        # fetch the global step and loss only when logging
                        global_step_val = tf.train.global_step(
                            sess, global_step)
                        _, loss_val = sess.run([train_step, loss],
                                               feed_dict=feed_dict)
                        print('global %d, epoch %d, iter %d: loss : %.4f' %
                              (global_step_val, epoch, batch_idx, loss_val))
                    else:
                        sess.run(train_step, feed_dict=feed_dict)

                    batch_idx += 1

            # ---------------------- validation ----------------------
            print('validate...')
            print('   compute global descriptors')
            input_data.reset_scan()
            test_size = input_data.get_test_dataset_size()
            sat_global_descriptor = np.zeros([test_size, 4096])
            grd_global_descriptor = np.zeros([test_size, 4096])
            val_i = 0
            while True:
                print('      progress %d' % val_i)
                batch_sat, batch_grd = input_data.next_batch_scan(batch_size)
                if batch_sat is None:
                    break  # all validation batches consumed
                feed_dict = {
                    sat_x: batch_sat,
                    grd_x: batch_grd,
                    keep_prob: 1.0  # no dropout at validation time
                }
                sat_global_val, grd_global_val = \
                    sess.run([sat_global, grd_global], feed_dict=feed_dict)

                sat_global_descriptor[
                    val_i:val_i + sat_global_val.shape[0], :] = sat_global_val
                grd_global_descriptor[
                    val_i:val_i + grd_global_val.shape[0], :] = grd_global_val
                val_i += sat_global_val.shape[0]

            print('   compute accuracy')
            val_accuracy = validate(grd_global_descriptor,
                                    sat_global_descriptor)
            with open('../Result/' + str(network_type) + '_accuracy.txt',
                      'a') as file:
                file.write(
                    str(epoch) + ' ' + str(batch_idx) + ' : ' +
                    str(val_accuracy) + '\n')
            print('   %d: accuracy = %.1f%%' % (epoch, val_accuracy * 100.0))

            # checkpoint this epoch
            model_dir = '../Model/' + network_type + '/' + str(epoch) + '/'
            if not os.path.exists(model_dir):
                os.makedirs(model_dir)
            save_path = saver.save(sess, model_dir + 'model.ckpt')
            print("Model saved in file: %s" % save_path)
Example #2
0
def train(start_epoch=1):
    '''
    Train the three-stream joint feature learning network and validate it
    after every epoch, checkpointing only when validation accuracy improves
    on the best seen so far.

    :param start_epoch: the epoch id to start training from. The first
        epoch is 1.
    :raises ValueError: if the global ``network_type`` is not
        'joint_feat_learning'.
    '''

    # import data
    input_data = InputData()

    # define placeholders: real aerial, ground panorama, and the
    # GAN-synthesized aerial image (third stream)
    sat_x = tf.placeholder(tf.float32, [None, 512, 512, 3], name='sat_x')
    grd_x = tf.placeholder(tf.float32, [None, 224, 1232, 3], name='grd_x')
    sat_x_synth = tf.placeholder(tf.float32, [None, 512, 512, 3],
                                 name='sat_x_synth')

    keep_prob = tf.placeholder(tf.float32)  # dropout keep probability
    learning_rate = tf.placeholder(tf.float32)

    # build model: three_stream_with_gan_imgs
    if network_type == 'joint_feat_learning':
        sat_global, grd_global, gan_sat_global = joint_feat_learning(
            sat_x, grd_x, sat_x_synth, keep_prob, is_training)
    else:
        # BUG FIX: previously this only printed and fell through to a
        # NameError on `sat_global`; fail fast with a clear exception.
        raise ValueError(
            'CONFIG ERROR: wrong network type, only joint_feat_learning valid')

    # define loss: ground<->aerial ranking loss plus a GAN-aerial<->aerial
    # consistency loss, weighted 10:1 in favour of the main task
    loss1 = compute_loss(sat_global, grd_global, 0)
    loss2 = compute_loss(gan_sat_global, sat_global, 0)
    loss = (10 * loss1 + loss2) / 11

    # set training
    global_step = tf.Variable(0, trainable=False)
    with tf.device('/gpu:0'):
        with tf.name_scope('train'):
            train_step = tf.train.AdamOptimizer(
                learning_rate, 0.9, 0.999).minimize(loss,
                                                    global_step=global_step)

    saver_full = tf.train.Saver(tf.global_variables(), max_to_keep=None)

    # run model
    print('run model...')
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.9
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        # NOTE(review): restoring pretrained two-stream weights (or an
        # intermediate joint-feature-learning checkpoint) was disabled in
        # the original; re-enable a restore here when resuming training.

        # Train
        best_accuracy = 0.0
        for epoch in range(start_epoch, start_epoch + number_of_epoch):
            batch_idx = 0  # renamed from `iter`, which shadowed the builtin
            while True:
                batch_sat, batch_grd = input_data.next_pair_batch(batch_size)
                if batch_sat is None:
                    break  # epoch exhausted

                # batch_sat carries 6 channels: [:3] real aerial,
                # [3:] GAN-synthesized aerial
                feed_dict = {
                    sat_x: batch_sat[:, :, :, :3],
                    grd_x: batch_grd,
                    sat_x_synth: batch_sat[:, :, :, 3:],
                    learning_rate: learning_rate_val,
                    keep_prob: keep_prob_val
                }
                if batch_idx % 50 == 0:
                    # fetch the global step and loss only when logging
                    global_step_val = tf.train.global_step(sess, global_step)
                    _, loss_val = sess.run([train_step, loss],
                                           feed_dict=feed_dict)
                    print('global %d, epoch %d, iter %d: loss : %.4f' %
                          (global_step_val, epoch, batch_idx, loss_val))
                else:
                    sess.run(train_step, feed_dict=feed_dict)

                batch_idx += 1

            # ---------------------- validation ----------------------
            print('validate...')
            print('   compute global descriptors')
            input_data.reset_scan()
            test_size = input_data.get_test_dataset_size()
            sat_global_descriptor = np.zeros([test_size, 1000])
            gan_sat_global_descriptor = np.zeros([test_size, 1000])
            grd_global_descriptor = np.zeros([test_size, 1000])
            val_i = 0
            while True:
                if val_i % 2000 == 0:
                    print('      progress %d' % val_i)

                batch_sat, batch_grd = input_data.next_batch_scan(batch_size)
                if batch_sat is None:
                    break  # all validation batches consumed
                feed_dict = {
                    sat_x: batch_sat[:, :, :, :3],
                    grd_x: batch_grd,
                    sat_x_synth: batch_sat[:, :, :, 3:],
                    keep_prob: 1.0  # no dropout at validation time
                }
                sat_global_val, grd_global_val, gan_sat_global_val = \
                    sess.run([sat_global, grd_global, gan_sat_global],
                             feed_dict=feed_dict)

                gan_sat_global_descriptor[
                    val_i:val_i +
                    gan_sat_global_val.shape[0], :] = gan_sat_global_val
                sat_global_descriptor[
                    val_i:val_i + sat_global_val.shape[0], :] = sat_global_val
                grd_global_descriptor[
                    val_i:val_i + grd_global_val.shape[0], :] = grd_global_val
                val_i += sat_global_val.shape[0]

            print('   compute gan+aerial accuracy')
            val_accuracy1 = validate(gan_sat_global_descriptor,
                                     sat_global_descriptor)
            print('   %d: accuracy = %.2f%%' % (epoch, val_accuracy1 * 100.0))

            print('   compute real+aerial accuracy')
            val_accuracy = validate(grd_global_descriptor,
                                    sat_global_descriptor)
            # BUG FIX: this accuracy was printed twice in the original
            print('   %d: accuracy = %.2f%%' % (epoch, val_accuracy * 100.0))
            with open('../Result/' + str(network_type) + '_accuracy.txt',
                      'a') as file:
                file.write(
                    str(epoch) + ' ' + str(batch_idx) + ' : ' +
                    str(val_accuracy) + '\n')

            # checkpoint (and dump descriptors) only on improvement
            model_dir = '../Model/' + network_type + '/'
            if best_accuracy < val_accuracy:
                best_accuracy = val_accuracy
                if not os.path.exists(model_dir):
                    os.makedirs(model_dir)
                save_path = saver_full.save(sess, model_dir + 'model.ckpt')
                print("Model saved in file: %s" % save_path)
                sio.savemat(
                    str(network_type) + '.mat', {
                        'grd_feats': grd_global_descriptor,
                        'sat_feats': sat_global_descriptor,
                        'gan_sat_feats': gan_sat_global_descriptor
                    })
            else:
                print("Model not saved for epoch:" + str(epoch))
Example #3
0
def train(start_epoch=1):
    '''
    Restore a pretrained CVM-NET model via its meta graph and continue
    training it (this variant performs no per-epoch validation).

    :param start_epoch: the epoch id to start training from. The first
        epoch is 1.
    :raises ValueError: if the global ``network_type`` is not one of the
        supported CVM-NET variants.
    '''

    # import data: (satellite filename, streetview filename, pano_id)
    # triples read from the train/validation list files
    input_data = InputData()

    # define placeholders for the input images
    sat_x = tf.placeholder(tf.float32, [None, 512, 512, 3], name='sat_x')  # satellite
    grd_x = tf.placeholder(tf.float32, [None, 224, 1232, 3], name='grd_x')  # ground
    keep_prob = tf.placeholder(tf.float32)  # dropout
    learning_rate = tf.placeholder(tf.float32)

    # build model (images are fed later through the placeholders)
    if network_type == 'CVM-NET-I':
        sat_global, grd_global = cvm_net_I(sat_x, grd_x, keep_prob,
                                           is_training)
    elif network_type == 'CVM-NET-II':
        sat_global, grd_global = cvm_net_II(sat_x, grd_x, keep_prob,
                                            is_training)
    else:
        # BUG FIX: previously this only printed and fell through to a
        # NameError on `sat_global`; fail fast with a clear exception.
        raise ValueError(
            'CONFIG ERROR: wrong network type, only CVM-NET-I and CVM-NET-II are valid'
        )

    # define loss
    loss = compute_loss(sat_global, grd_global, 0)

    # set training
    global_step = tf.Variable(0, trainable=False)
    with tf.device('/gpu:0'):
        with tf.name_scope('train'):
            train_step = tf.train.AdamOptimizer(
                learning_rate, 0.9, 0.999).minimize(loss,
                                                    global_step=global_step)

    saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)

    # run model
    print('run model...')
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.9
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        print('load model...')
        # NOTE(review): this chdir changes the working directory for the
        # rest of the process; the relative paths below depend on it.
        os.chdir('../../Model/')
        cwd = os.getcwd()
        load_model_path = cwd + '/' + network_name + '/' + network_name + '_model'
        print(load_model_path)
        # import_meta_graph returns a Saver for the restored graph; it
        # deliberately replaces the Saver built above.
        saver = tf.train.import_meta_graph(load_model_path +
                                           "/model.ckpt.meta")
        load_model_path += '/model.ckpt'
        saver.restore(sess, load_model_path)
        print("   Model loaded from: %s" % load_model_path)
        print('load model...FINISHED')
        import tensorflow.contrib.slim as slim
        model_vars = tf.trainable_variables()
        slim.model_analyzer.analyze_vars(model_vars, print_info=True)

        print('training...')

        # Train (debug prints from the original removed)
        for epoch in range(start_epoch, start_epoch + number_of_epoch):
            batch_idx = 0  # renamed from `iter`, which shadowed the builtin
            while True:
                batch_sat, batch_grd = input_data.next_pair_batch(batch_size)
                if batch_sat is None:
                    break  # epoch exhausted

                feed_dict = {
                    sat_x: batch_sat,
                    grd_x: batch_grd,
                    learning_rate: learning_rate_val,
                    keep_prob: keep_prob_val
                }
                if batch_idx % 20 == 0:
                    # fetch the global step and loss only when logging
                    global_step_val = tf.train.global_step(sess, global_step)
                    _, loss_val = sess.run([train_step, loss],
                                           feed_dict=feed_dict)
                    print('global %d, epoch %d, iter %d: loss : %.4f' %
                          (global_step_val, epoch, batch_idx, loss_val))
                else:
                    sess.run(train_step, feed_dict=feed_dict)

                batch_idx += 1
Example #4
0
def train(start_epoch=1):
    '''
    Restore a pretrained CVM-NET model, continue training it, and after
    every epoch validate and live-plot train loss, validation loss and the
    two accuracy values with matplotlib.

    :param start_epoch: the epoch id to start training from. The first
        epoch is 1.
    :raises ValueError: if the global ``network_type`` is not a supported
        CVM-NET variant, or if ``start_epoch != 1`` (no resume checkpoint
        path is configured for that case).
    '''

    # import data: (satellite filename, streetview filename, pano_id)
    input_data = InputData()

    # define placeholders for the input images
    sat_x = tf.placeholder(tf.float32, [None, 512, 512, 3], name='sat_x')  # satellite
    grd_x = tf.placeholder(tf.float32, [None, 224, 1232, 3], name='grd_x')  # ground
    keep_prob = tf.placeholder(tf.float32)  # dropout
    learning_rate = tf.placeholder(tf.float32)

    # build model (images are fed later through the placeholders)
    if network_type == 'CVM-NET-I':
        sat_global, grd_global = cvm_net_I(sat_x, grd_x, keep_prob,
                                           is_training)
    elif network_type == 'CVM-NET-II':
        sat_global, grd_global = cvm_net_II(sat_x, grd_x, keep_prob,
                                            is_training)
    else:
        # BUG FIX: previously this only printed and fell through to a
        # NameError on `sat_global`; fail fast with a clear exception.
        raise ValueError(
            'CONFIG ERROR: wrong network type, only CVM-NET-I and CVM-NET-II are valid'
        )

    # define loss
    loss = compute_loss(sat_global, grd_global, 0)

    # set training
    global_step = tf.Variable(0, trainable=False)
    with tf.device('/gpu:0'):
        with tf.name_scope('train'):
            train_step = tf.train.AdamOptimizer(
                learning_rate, 0.9, 0.999).minimize(loss,
                                                    global_step=global_step)

    saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)

    # run model
    print('run model...')
    config = tf.ConfigProto(log_device_placement=False,
                            allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.9

    # interactive plot: blue = train loss, red = val loss,
    # black/green = the two accuracy values returned by validate()
    plt.ion()
    plt.xlabel('epochs')
    plt.ylabel('loss/accuracy')
    plt.show()
    p = []
    p_val = []
    p_acc = []
    p_acc_ = []
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        print('load model...')
        # NOTE(review): this chdir changes the working directory for the
        # rest of the process; the relative paths below depend on it.
        os.chdir('../Model/')
        cwd = os.getcwd()
        if start_epoch == 1:
            load_model_path = cwd + '/' + network_name + '/' + network_name + '_model'
        else:
            # BUG FIX: load_model_path was left unbound in this case and
            # produced a NameError below; fail with a clear message.
            raise ValueError(
                'no checkpoint path configured for start_epoch != 1')
        # import_meta_graph returns a Saver for the restored graph; it
        # deliberately replaces the Saver built above.
        saver = tf.train.import_meta_graph(load_model_path +
                                           "/model.ckpt.meta")
        load_model_path += '/model.ckpt'
        saver.restore(sess, load_model_path)
        print("   Model loaded from: %s" % load_model_path)
        print('load model...FINISHED')

        print('training...from epoch {}'.format(start_epoch))
        # Train
        for epoch in range(start_epoch, start_epoch + number_of_epoch):
            batch_idx = 0  # renamed from `iter`, which shadowed the builtin
            train_loss = []
            val_loss = []
            while True:
                batch_sat, batch_grd = input_data.next_pair_batch(batch_size)
                if batch_sat is None:
                    break  # epoch exhausted

                global_step_val = tf.train.global_step(sess, global_step)

                feed_dict = {
                    sat_x: batch_sat,
                    grd_x: batch_grd,
                    learning_rate: learning_rate_val,
                    keep_prob: keep_prob_val
                }
                _, loss_val = sess.run([train_step, loss],
                                       feed_dict=feed_dict)
                # BUG FIX: loss_val was appended twice per iteration
                train_loss.append(loss_val)
                print('global %d, epoch %d, iter %d: loss : %.4f' %
                      (global_step_val, epoch, batch_idx, loss_val))
                batch_idx += 1

            plt.legend()
            p += [np.mean(train_loss)]
            plt.plot(p, 'b-')
            plt.pause(0.05)

            # ---------------------- validation ----------------------
            print('validate...')
            print('   compute global descriptors')
            input_data.reset_scan()
            test_size = input_data.get_test_dataset_size()
            sat_global_descriptor = np.zeros([test_size, 4096])
            grd_global_descriptor = np.zeros([test_size, 4096])
            val_i = 0
            while True:
                print('      progress %d' % val_i)
                batch_sat, batch_grd = input_data.next_batch_scan(batch_size)
                if batch_sat is None:
                    break  # all validation batches consumed
                feed_dict = {
                    sat_x: batch_sat,
                    grd_x: batch_grd,
                    keep_prob: 1.0  # no dropout at validation time
                }

                sat_global_val, grd_global_val = \
                    sess.run([sat_global, grd_global], feed_dict=feed_dict)

                val_loss.append(sess.run(loss, feed_dict=feed_dict))

                sat_global_descriptor[
                    val_i:val_i + sat_global_val.shape[0], :] = sat_global_val
                grd_global_descriptor[
                    val_i:val_i + grd_global_val.shape[0], :] = grd_global_val
                val_i += sat_global_val.shape[0]

            p_val += [np.mean(val_loss)]
            plt.plot(p_val, 'r-')
            plt.pause(0.05)

            print('   compute accuracy')
            val_accuracy, val_accuracy_ = validate(grd_global_descriptor,
                                                   sat_global_descriptor)
            p_acc += [val_accuracy]
            p_acc_ += [val_accuracy_]
            plt.plot(p_acc, 'k-')
            plt.pause(0.05)
            plt.plot(p_acc_, 'g-')
            plt.pause(0.05)

            with open('../Result/' + str(network_type) + '_accuracy.txt',
                      'a') as file:
                file.write(
                    str(epoch) + ' ' + str(batch_idx) + ' : ' +
                    str(val_accuracy) + '\n')
            print('   %d: accuracy = %.1f%%' % (epoch, val_accuracy * 100.0))
            print('accuracy_ ', val_accuracy_)

            # resolve the checkpoint directory relative to the model tree
            os.chdir('../Model/CVM-Net-I/CVM-Net-I_sydney_dense/')
            cwd = os.getcwd()
            os.chdir('../../../CVM-Net/')

            model_dir = cwd + '/' + network_type + '/' + str(epoch) + '/'
            if not os.path.exists(model_dir):
                os.makedirs(model_dir)
            # BUG FIX: the save was nested inside the makedirs branch, so a
            # pre-existing epoch directory silently skipped checkpointing.
            if epoch > 70 or epoch % 5 == 0:
                save_path = saver.save(sess, model_dir + 'model.ckpt')
                print("Model saved in file: %s" % save_path)