Exemplo n.º 1
0
def main():
    """Train the discriminator on its own and checkpoint its best test MAP.

    No generator is constructed here: each round of training lists is drawn
    by calling ``generate_samples`` with only the positive/negative pools
    (its exact sampling scheme is defined elsewhere). Every ``D_DISPLAY``
    epochs the discriminator is evaluated with ``MAP`` and saved to
    ``DIS_MODEL_BEST_FILE`` whenever the mean of the I2T and T2I test MAP
    improves.

    Reads module-level configuration (``GPU_ID``, ``D_EPOCH``, ``GS_EPOCH``,
    ``D_DISPLAY``, model dimensions/hyper-parameters) and the
    ``train_i2t_*`` / ``train_t2i_*`` data defined elsewhere in the file.
    """
    with tf.device('/gpu:' + str(GPU_ID)):
        # param=None: no saved parameters are passed in (presumably a
        # from-scratch initialisation -- confirm against DIS).
        discriminator = DIS(IMAGE_DIM,
                            TEXT_DIM,
                            HIDDEN_DIM,
                            OUTPUT_DIM,
                            WEIGHT_DECAY,
                            D_LEARNING_RATE,
                            BETA,
                            GAMMA,
                            loss='svm',
                            param=None)

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        try:
            sess.run(tf.initialize_all_variables())

            print('start adversarial training')
            map_best_val_dis = 0.0

            print('Training D ...')
            for d_epoch in range(D_EPOCH):
                print('d_epoch: ' + str(d_epoch))
                # Re-draw the training lists every GS_EPOCH epochs. d_epoch == 0
                # always passes this test, so both lists exist before first use.
                if d_epoch % GS_EPOCH == 0:
                    print('negative text sampling for d using g ...')
                    dis_train_i2t_list = generate_samples(train_i2t_pos,
                                                          train_i2t_neg, 'i2t')
                    print('negative image sampling for d using g ...')
                    dis_train_t2i_list = generate_samples(train_t2i_pos,
                                                          train_t2i_neg, 't2i')

                discriminator = train_discriminator(sess, discriminator,
                                                    dis_train_i2t_list, 'i2t')
                discriminator = train_discriminator(sess, discriminator,
                                                    dis_train_t2i_list, 't2i')

                if (d_epoch + 1) % D_DISPLAY == 0:
                    # MAP also returns I2I and T2T scores; they are not used here.
                    i2t_test_map, t2i_test_map, _, _ = MAP(sess, discriminator)
                    print('I2T_Test_MAP: %.4f' % i2t_test_map)
                    print('T2I_Test_MAP: %.4f' % t2i_test_map)

                    # Model selection uses the mean of the two cross-modal scores.
                    average_map = 0.5 * (i2t_test_map + t2i_test_map)
                    if average_map > map_best_val_dis:
                        map_best_val_dis = average_map
                        discriminator.save_model(sess, DIS_MODEL_BEST_FILE)
        finally:
            # Release the session (and its GPU memory) even if training raises.
            sess.close()
Exemplo n.º 2
0
def _evaluate_and_record(sess, model, epoch, phase, step):
    """Evaluate ``model`` on the I2T and T2I test sets and log the scores.

    Each MAP line is printed and then both lines are appended to
    ``record.txt``, in the form ``E<epoch> <phase><step> I2T_Test_MAP: x.xxxx``.

    Args:
        sess: the active TensorFlow session.
        model: the discriminator or generator passed through to ``MAP``.
        epoch: outer adversarial epoch index (printed after ``E``).
        phase: tag printed before ``step`` -- ``'D'`` or ``'G'`` below.
        step: inner discriminator/generator epoch index.

    Returns:
        The mean of the I2T and T2I test MAP.
    """
    i2t_test_map = MAP(sess, model, test_i2t_pos, test_i2t, feature_dict,
                       'i2t')
    i2t_line = 'E%d %s%d I2T_Test_MAP: %.4f' % (epoch, phase, step,
                                                 i2t_test_map)
    print(i2t_line)
    t2i_test_map = MAP(sess, model, test_t2i_pos, test_t2i, feature_dict,
                       't2i')
    t2i_line = 'E%d %s%d T2I_Test_MAP: %.4f' % (epoch, phase, step,
                                                 t2i_test_map)
    print(t2i_line)

    with open('record.txt', 'a') as record_file:
        record_file.write(i2t_line + '\n')
        record_file.write(t2i_line + '\n')

    return 0.5 * (i2t_test_map + t2i_test_map)


def main():
    """Alternate discriminator and generator training for WHOLE_EPOCH rounds.

    Each round trains the discriminator for ``D_EPOCH`` epochs. Its training
    lists are re-drawn with the generator every ``GS_EPOCH`` epochs. The
    generator is then trained against the discriminator for ``G_EPOCH``
    epochs. Both models are evaluated periodically (``D_DISPLAY`` /
    ``G_DISPLAY``). Each model is saved to its ``*_BEST_FILE`` whenever its
    mean test MAP improves and to its ``*_NEWEST_FILE`` after every inner
    epoch.

    Reads module-level configuration and data (``GPU_ID``, epoch counts,
    model hyper-parameters, ``train_*``/``test_*`` lists, ``feature_dict``)
    defined elsewhere in the file.
    """
    with tf.device('/gpu:' + str(GPU_ID)):
        # The discriminator receives parameters loaded from the pretrain file;
        # the generator is built with param=None.
        # NOTE: cPickle can execute arbitrary code while loading -- only point
        # DIS_MODEL_PRETRAIN_FILE at trusted checkpoints.
        with open(DIS_MODEL_PRETRAIN_FILE, 'rb') as param_file:
            dis_param = cPickle.load(param_file)
        discriminator = DIS(IMAGE_DIM,
                            TEXT_DIM,
                            HIDDEN_DIM,
                            OUTPUT_DIM,
                            WEIGHT_DECAY,
                            D_LEARNING_RATE,
                            BETA,
                            GAMMA,
                            param=dis_param)
        generator = GEN(IMAGE_DIM,
                        TEXT_DIM,
                        HIDDEN_DIM,
                        OUTPUT_DIM,
                        CLASS_DIM,
                        WEIGHT_DECAY,
                        G_LEARNING_RATE,
                        param=None)

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        try:
            sess.run(tf.initialize_all_variables())

            print('start adversarial training')
            map_best_val_gen = 0.0
            map_best_val_dis = 0.0

            for epoch in range(WHOLE_EPOCH):
                print('Training D ...')
                for d_epoch in range(D_EPOCH):
                    print('d_epoch: ' + str(d_epoch))
                    # Re-sample every GS_EPOCH epochs; d_epoch == 0 always
                    # passes, so both lists exist before first use.
                    if d_epoch % GS_EPOCH == 0:
                        print('negative text sampling for d using g ...')
                        dis_train_i2t_list = generate_samples(
                            sess, generator, train_i2t, train_i2t_pos,
                            train_i2t_neg, 'i2t')
                        print('negative image sampling for d using g ...')
                        dis_train_t2i_list = generate_samples(
                            sess, generator, train_t2i, train_t2i_pos,
                            train_t2i_neg, 't2i')

                    discriminator = train_discriminator(
                        sess, discriminator, dis_train_i2t_list, 'i2t')
                    discriminator = train_discriminator(
                        sess, discriminator, dis_train_t2i_list, 't2i')

                    if (d_epoch + 1) % D_DISPLAY == 0:
                        average_map = _evaluate_and_record(
                            sess, discriminator, epoch, 'D', d_epoch)
                        if average_map > map_best_val_dis:
                            map_best_val_dis = average_map
                            discriminator.save_model(sess, DIS_MODEL_BEST_FILE)
                    discriminator.save_model(sess, DIS_MODEL_NEWEST_FILE)

                print('Training G ...')
                for g_epoch in range(G_EPOCH):
                    print('g_epoch: ' + str(g_epoch))
                    generator = train_generator(sess, generator, discriminator,
                                                train_i2t, train_i2t_pos, 'i2t')
                    generator = train_generator(sess, generator, discriminator,
                                                train_t2i, train_t2i_pos, 't2i')

                    if (g_epoch + 1) % G_DISPLAY == 0:
                        average_map = _evaluate_and_record(
                            sess, generator, epoch, 'G', g_epoch)
                        if average_map > map_best_val_gen:
                            map_best_val_gen = average_map
                            generator.save_model(sess, GEN_MODEL_BEST_FILE)
                    generator.save_model(sess, GEN_MODEL_NEWEST_FILE)
        finally:
            # Release the session (and its GPU memory) even if training raises.
            sess.close()