Ejemplo n.º 1
0
def main():
    """Train the discriminator on its own and keep the best-scoring model.

    Builds a freshly initialised ``DIS`` model (``loss='svm'``, no pretrained
    parameters) under device ``'/gpu:' + str(GPU_ID)`` and trains it for
    ``D_EPOCH`` epochs on image-to-text ('i2t') and text-to-image ('t2i')
    sample lists.  The lists are rebuilt with ``generate_samples`` every
    ``GS_EPOCH`` epochs.  Every ``D_DISPLAY`` epochs the I2T and T2I test mAP
    are printed, and the model is saved to ``DIS_MODEL_BEST_FILE`` whenever
    their average beats the best value seen so far.
    """
    with tf.device('/gpu:' + str(GPU_ID)):
        # To warm-start instead, load DIS_MODEL_NEWEST_FILE with cPickle and
        # pass the result as ``param=`` below.
        discriminator = DIS(IMAGE_DIM,
                            TEXT_DIM,
                            HIDDEN_DIM,
                            OUTPUT_DIM,
                            WEIGHT_DECAY,
                            D_LEARNING_RATE,
                            BETA,
                            GAMMA,
                            loss='svm',
                            param=None)

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        # Close the session in ``finally`` so an exception during training or
        # evaluation does not leak it.
        try:
            sess.run(tf.initialize_all_variables())

            print('start adversarial training')
            map_best_val_dis = 0.0

            print('Training D ...')
            for d_epoch in range(D_EPOCH):
                print('d_epoch: ' + str(d_epoch))
                # Always true at d_epoch == 0, so both lists are bound before
                # their first use below.
                if d_epoch % GS_EPOCH == 0:
                    print('negative text sampling for d using g ...')
                    dis_train_i2t_list = generate_samples(train_i2t_pos,
                                                          train_i2t_neg, 'i2t')
                    print('negative image sampling for d using g ...')
                    dis_train_t2i_list = generate_samples(train_t2i_pos,
                                                          train_t2i_neg, 't2i')

                discriminator = train_discriminator(sess, discriminator,
                                                    dis_train_i2t_list, 'i2t')
                discriminator = train_discriminator(sess, discriminator,
                                                    dis_train_t2i_list, 't2i')

                if (d_epoch + 1) % D_DISPLAY == 0:
                    # MAP returns four scores; the last two (i2i / t2t) are
                    # neither reported nor used for model selection.
                    i2t_test_map, t2i_test_map, _, _ = MAP(sess, discriminator)
                    print('I2T_Test_MAP: %.4f' % i2t_test_map)
                    print('T2I_Test_MAP: %.4f' % t2i_test_map)

                    average_map = 0.5 * (i2t_test_map + t2i_test_map)
                    if average_map > map_best_val_dis:
                        map_best_val_dis = average_map
                        discriminator.save_model(sess, DIS_MODEL_BEST_FILE)
        finally:
            sess.close()
Ejemplo n.º 2
0
def test():
    """Export discriminator features for the database and test sets to a .mat file.

    Rebuilds ``DIS`` from the parameters pickled in ``DIS_MODEL_BEST_FILE``,
    runs ``extract_feature`` on ``database_img``, ``database_txt``,
    ``test_img`` and ``test_txt``, and writes the four results with
    ``sio.savemat`` to ``./result/DIS_mir_<OUTPUT_DIM>.mat`` under the keys
    ``B_I_db``, ``B_T_db``, ``B_I_te`` and ``B_T_te``.
    """
    # Pickle data must be read in binary mode; ``with`` closes the handle.
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    with open(DIS_MODEL_BEST_FILE, 'rb') as param_file:
        discriminator_param = cPickle.load(param_file)
    discriminator = DIS(IMAGE_DIM,
                        TEXT_DIM,
                        HIDDEN_DIM,
                        OUTPUT_DIM,
                        WEIGHT_DECAY,
                        D_LEARNING_RATE,
                        BETA,
                        GAMMA,
                        loss='svm',
                        param=discriminator_param)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    # Close the session even if feature extraction or saving fails.
    try:
        sess.run(tf.initialize_all_variables())

        I_db = extract_feature(sess, discriminator, database_img)
        T_db = extract_feature(sess, discriminator, database_txt)
        I_te = extract_feature(sess, discriminator, test_img)
        T_te = extract_feature(sess, discriminator, test_txt)
        sio.savemat('./result/DIS_mir_' + str(OUTPUT_DIM) + '.mat', {
            'B_I_db': I_db,
            'B_T_db': T_db,
            'B_I_te': I_te,
            'B_T_te': T_te
        })
    finally:
        sess.close()
Ejemplo n.º 3
0
def main():
    """Train the five-modality discriminator and checkpoint it with a Saver.

    Trains a ``DIS`` model on the I, T, D, V and A sample lists for
    ``WHOLE_EPOCH`` x ``D_EPOCH`` epochs.  The lists are rebuilt with
    ``generate_samples`` at the start of every outer epoch.  Every
    ``D_DISPLAY`` epochs the test mAP from ``MAP_ARGV`` is printed and, on a
    new best, the trainable variables are saved to ``DIS_MODEL_BEST_FILE``.
    They are also saved to ``DIS_MODEL_NEWEST_FILE`` after every epoch.
    """
    discriminator = DIS(I_DIM, T_DIM, A_DIM, V_DIM, D_DIM, HIDDEN_DIM,
                        OUTPUT_DIM, WEIGHT_DECAY, D_LEARNING_RATE, BETA, GAMMA)

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    # Close the session even if training, evaluation or saving fails.
    try:
        sess.run(tf.initialize_all_variables())

        # Only trainable variables are checkpointed.  To resume training,
        # call saver.restore(sess, DIS_MODEL_NEWEST_FILE) here.
        saver = tf.train.Saver(var_list=tf.trainable_variables())

        # (tag passed to train_discriminator, index passed to
        # generate_samples), in the order samples are generated and trained.
        modalities = (('I', 0), ('T', 1), ('D', 2), ('V', 3), ('A', 4))

        print('start adversarial training')
        map_best_val_dis = 0.0

        for _ in range(WHOLE_EPOCH):
            print('Training D ...')
            for d_epoch in range(D_EPOCH):
                print('d_epoch: ' + str(d_epoch))
                # Fresh samples once per outer epoch; d_epoch == 0 runs first,
                # so dis_train_lists is always bound before it is read.
                if d_epoch == 0:
                    print('negative sampling for d using g ...')
                    dis_train_lists = {}
                    for tag, index in modalities:
                        dis_train_lists[tag] = generate_samples(index)

                for tag, _ in modalities:
                    discriminator = train_discriminator(
                        sess, discriminator, dis_train_lists[tag], tag)

                if (d_epoch + 1) % D_DISPLAY == 0:
                    test_map = MAP_ARGV(sess, discriminator, test_feature,
                                        database_feature, test_label,
                                        database_label, OUTPUT_DIM)
                    print('Test_MAP: %.4f' % test_map)
                    if test_map > map_best_val_dis:
                        map_best_val_dis = test_map
                        saver.save(sess, DIS_MODEL_BEST_FILE)
                saver.save(sess, DIS_MODEL_NEWEST_FILE)
    finally:
        sess.close()
Ejemplo n.º 4
0
def test(result_dir='/home/huhengtong/UKD/data/'):
    """Compute teacher k-NN results from the best I2I discriminator and save them.

    Rebuilds ``DIS`` from the parameters pickled in
    ``DIS_MODEL_BEST_I2I_FILE`` and extracts image and text features for
    ``train_img`` / ``train_txt``.  It then runs ``get_knn`` on those features
    and saves the two results as ``teacher_KNN_img.npy`` and
    ``teacher_KNN_txt.npy`` in *result_dir*.

    Args:
        result_dir: Output directory.  Defaults to the previously hard-coded
            location, so existing ``test()`` callers are unaffected.
    """
    import os

    # ``with`` closes the handle; the original leaked it.
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    with open(DIS_MODEL_BEST_I2I_FILE, 'rb') as param_file:
        discriminator_param = pickle.load(param_file)
    discriminator = DIS(IMAGE_DIM, TEXT_DIM, HIDDEN_DIM, OUTPUT_DIM,
                        WEIGHT_DECAY, D_LEARNING_RATE, BETA, GAMMA,
                        loss='svm', param=discriminator_param)

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    # The original never closed the session; release it however we exit.
    try:
        sess.run(tf.initialize_all_variables())

        img_feats = extract_feature(sess, discriminator, train_img, 'image')
        txt_feats = extract_feature(sess, discriminator, train_txt, 'text')

        knn_img, knn_txt = get_knn(img_feats, txt_feats)

        # A leftover pdb.set_trace() used to sit here and halted every
        # non-interactive run before anything was saved.
        np.save(os.path.join(result_dir, 'teacher_KNN_img.npy'), knn_img)
        np.save(os.path.join(result_dir, 'teacher_KNN_txt.npy'), knn_txt)
    finally:
        sess.close()
Ejemplo n.º 5
0
def main():
    """Run the adversarial loop, alternating discriminator and generator training.

    ``DIS`` is initialised from the parameters pickled in
    ``DIS_MODEL_PRETRAIN_FILE``; ``GEN`` starts from scratch (``param=None``).
    Each of the ``WHOLE_EPOCH`` rounds has two phases:

    * Train D for ``D_EPOCH`` epochs on i2t/t2i sample lists produced by the
      generator, regenerated every ``GS_EPOCH`` epochs.
    * Train G against D for ``G_EPOCH`` epochs.

    I2T/T2I mAP is printed every ``D_DISPLAY`` epochs for D and every
    ``G_DISPLAY`` epochs for G.  No model is saved.
    """
    with tf.device('/gpu:' + str(GPU_ID)):
        # ``with`` closes the handle; the original leaked it.
        # NOTE: unpickling can execute arbitrary code -- only load trusted files.
        with open(DIS_MODEL_PRETRAIN_FILE, 'rb') as param_file:
            dis_param = pickle.load(param_file)
        discriminator = DIS(IMAGE_DIM,
                            TEXT_DIM,
                            HIDDEN_DIM,
                            OUTPUT_DIM,
                            WEIGHT_DECAY,
                            D_LEARNING_RATE,
                            BETA,
                            GAMMA,
                            param=dis_param)
        generator = GEN(IMAGE_DIM,
                        TEXT_DIM,
                        HIDDEN_DIM,
                        OUTPUT_DIM,
                        CLASS_DIM,
                        WEIGHT_DECAY,
                        G_LEARNING_RATE,
                        param=None)

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        # Close the session even if any training/evaluation step fails.
        try:
            sess.run(tf.initialize_all_variables())

            print('start adversarial training')

            for _ in range(WHOLE_EPOCH):
                print('Training D ...')
                for d_epoch in range(D_EPOCH):
                    print('d_epoch: ' + str(d_epoch))
                    # Always true at d_epoch == 0, so both lists are bound
                    # before their first use below.
                    if d_epoch % GS_EPOCH == 0:
                        print('negative text sampling for d using g ...')
                        dis_train_i2t_list = generate_samples(
                            sess, generator, train_i2t, train_i2t_pos,
                            train_i2t_neg, 'i2t')
                        print('negative image sampling for d using g ...')
                        dis_train_t2i_list = generate_samples(
                            sess, generator, train_t2i, train_t2i_pos,
                            train_t2i_neg, 't2i')

                    # Only the cross-modal directions are trained; i2i/t2t
                    # training is disabled.
                    discriminator = train_discriminator(
                        sess, discriminator, dis_train_i2t_list, 'i2t')
                    discriminator = train_discriminator(
                        sess, discriminator, dis_train_t2i_list, 't2i')

                    if (d_epoch + 1) % D_DISPLAY == 0:
                        i2t_test_map, t2i_test_map = MAP(sess, discriminator)
                        print(
                            '---------------------------------------------------------------'
                        )
                        print('train_I2T_Test_MAP: %.4f' % i2t_test_map)
                        print('train_T2I_Test_MAP: %.4f' % t2i_test_map)

                print('Training G ...')
                for g_epoch in range(G_EPOCH):
                    print('g_epoch: ' + str(g_epoch))
                    generator = train_generator(sess, generator, discriminator,
                                                train_i2t, train_i2t_pos, 'i2t')
                    generator = train_generator(sess, generator, discriminator,
                                                train_t2i, train_t2i_pos, 't2i')

                    if (g_epoch + 1) % G_DISPLAY == 0:
                        # NOTE(review): MAP is called with 2 arguments for D
                        # above but 8 here.  Unless MAP gives defaults to the
                        # trailing 6 parameters, one of the two calls raises
                        # TypeError -- confirm against MAP's definition.
                        # The printed labels are also identical to D's.
                        i2t_test_map, t2i_test_map = MAP(sess, generator,
                                                         test_i2t_pos, test_i2t,
                                                         test_t2i_pos, test_t2i,
                                                         feature_dict, label_dict)
                        print(
                            '---------------------------------------------------------------'
                        )
                        print('train_I2T_Test_MAP: %.4f' % i2t_test_map)
                        print('train_T2I_Test_MAP: %.4f' % t2i_test_map)
        finally:
            sess.close()