# Example #1
def main():
    """Adversarial training loop for a retrieval GAN.

    Per epoch, alternates between:
      * sampling ranked candidates from the generator to build
        positive/negative pairs for the discriminator,
      * training the discriminator on those pairs,
      * policy-gradient updates of the generator using the
        discriminator's reward on sampled rankings.

    The best generator (by test MAP) is checkpointed to
    GAN_MODEL_BEST_FILE and re-evaluated at the end.

    Depends on module-level globals: query_train_feature/label,
    query_test_feature/label, pred_feature/pred_label,
    unsigned_list_pred_index, combine, the GEN/DIS classes, the
    MAP/MAP_G/MAP_test helpers and the size/learning-rate constants.
    """
    print("load initial model ...")

    generator = GEN(FEATURE_SIZE,
                    G_WEIGHT_DECAY,
                    G_LEARNING_RATE,
                    LEN_UNSIGNED,
                    param=None)
    print('Gen Done!!!')
    discriminator = DIS(FEATURE_SIZE,
                        D_WEIGHT_DECAY,
                        D_LEARNING_RATE,
                        LEN_GAN,
                        param=None)
    print("DIS Done!!!")

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    print(GAN_MODEL_BEST_FILE)
    G_map_best = 0
    Test_map_best = 0

    for epoch in range(10):
        print("epoch" + str(epoch))

        # Randomly draw QUERY_TRAIN_SIZE queries and let the generator
        # rank the prediction set for each of them (D-phase data).
        random_query_D_feature = []
        random_query_D_label = []
        generated_data = []
        neg_data = []
        for index_query in range(0, QUERY_TRAIN_SIZE):
            # BUGFIX: the original compared against QUERY_TRAIN_SIZE,
            # which range() never yields, so the final progress line
            # never printed.
            if index_query % 10 == 0 or index_query == QUERY_TRAIN_SIZE - 1:
                print("random_query_from_G_for_D " + str(index_query))

            # Pick a random training-query index.
            query = random.randint(0, TRAIN_SIZE - 1)
            random_query_D_feature.append(query_train_feature[query])
            random_query_D_label.append(query_train_label[query])

            current_query_feature = []
            current_query_feature.append(query_train_feature[query])
            current_query_feature = np.asarray(current_query_feature)

            # Score every candidate for this query, then rank them by
            # softmax probability.
            pred_list_score = sess.run(generator.pred_score,
                                       feed_dict={
                                           generator.query_data:
                                           current_query_feature,
                                           generator.pred_data: pred_feature
                                       })

            # BUGFIX: subtract the max before exponentiating for
            # numerical stability (identical probabilities, no overflow;
            # the G-training snippet elsewhere in this file already does
            # this).
            exp_rating = np.exp(pred_list_score - np.max(pred_list_score))
            prob = exp_rating / np.sum(exp_rating)
            sortlist = combine(unsigned_list_pred_index, prob)
            sortlist.sort(key=lambda x: x[1], reverse=True)

            # Top-LEN_GAN candidates become "generated" data, bottom-
            # LEN_GAN become negatives.
            # Tuple layout: (query index, dataset image index, feature).
            for i in range(0, LEN_GAN):
                generated_data.append((index_query, sortlist[i][0],
                                       pred_feature[int(sortlist[i][0])]))
            for j in range(PRED_SIZE - LEN_GAN, PRED_SIZE):
                neg_data.append((index_query, sortlist[j][0],
                                 pred_feature[int(sortlist[j][0])]))

        # Train D
        print('Training D ...')
        for d_epoch in range(10):
            print('d_epoch' + str(d_epoch))
            for index_query in range(0, QUERY_TRAIN_SIZE):
                # One query's feature per update step.
                input_query = []
                input_query.append(random_query_D_feature[index_query])
                # Slice this query's ranked positives/negatives out of the
                # flat generated_data / neg_data lists.
                input_gan = []
                input_neg = []
                for index_gan in range(0, LEN_GAN):
                    input_gan.append(generated_data[index_query * LEN_GAN +
                                                    index_gan][2])
                for index_gan in range(0, LEN_GAN):
                    input_neg.append(neg_data[index_query * LEN_GAN +
                                              index_gan][2])
                _ = sess.run(discriminator.d_updates,
                             feed_dict={
                                 discriminator.query_data: input_query,
                                 discriminator.gan_data: input_gan,
                                 discriminator.neg_data: input_neg
                             })
            # Evaluate the discriminator after each D-epoch.
            D_map = MAP(sess, discriminator, random_query_D_feature,
                        random_query_D_label, pred_label, generated_data,
                        QUERY_TRAIN_SIZE, LEN_GAN)
            print("map:", "map_D", D_map)

        # Train G
        print('Training G ...')
        number_index = np.random.permutation(TRAIN_SIZE)
        number = 0
        for g_epoch in range(10):
            print('g_epoch' + str(g_epoch))
            # Draw QUERY_TRAIN_SIZE queries by walking the shuffled
            # permutation (G-phase data).
            random_query_G_feature = []
            random_query_G_label = []
            generated_data = []
            neg_data = []
            for index_query in range(0, QUERY_TRAIN_SIZE):
                # BUGFIX: same off-by-one as in the D-sampling loop.
                if index_query % 10 == 0 or index_query == QUERY_TRAIN_SIZE - 1:
                    print("random_query_from_G_for_G " + str(index_query))

                # Walk the permutation, wrapping at the end.
                # BUGFIX: the original reset at TRAIN_SIZE - 1, silently
                # skipping the last permutation entry forever.
                if number == TRAIN_SIZE:
                    number = 0
                query = number_index[number]
                number = number + 1
                random_query_G_feature.append(query_train_feature[query])
                random_query_G_label.append(query_train_label[query])

                current_query_feature_un = []
                current_query_feature_un.append(query_train_feature[query])
                current_query_feature_un = np.asarray(current_query_feature_un)

                # Score + softmax-rank the candidates for this query.
                pred_list_score = sess.run(generator.pred_score,
                                           feed_dict={
                                               generator.query_data:
                                               current_query_feature_un,
                                               generator.pred_data:
                                               pred_feature
                                           })
                # Numerically stable softmax (see D-phase above).
                exp_rating = np.exp(pred_list_score -
                                    np.max(pred_list_score))
                prob = exp_rating / np.sum(exp_rating)
                sortlist = combine(unsigned_list_pred_index, prob)
                sortlist.sort(key=lambda x: x[1], reverse=True)
                # Top-/bottom-LEN_GAN split, as in the D-phase.
                for i in range(0, LEN_GAN):
                    generated_data.append((index_query, sortlist[i][0],
                                           pred_feature[int(sortlist[i][0])]))
                for j in range(PRED_SIZE - LEN_GAN, PRED_SIZE):
                    neg_data.append((index_query, sortlist[j][0],
                                     pred_feature[int(sortlist[j][0])]))

                # Collect the retrieved candidate features for this query.
                gan_list_feature = []
                neg_list_feature = []
                for index_gan in range(0, LEN_GAN):
                    gan_list_feature.append(
                        generated_data[index_query * LEN_GAN + index_gan][2])
                gan_list_feature = np.asarray(gan_list_feature)
                for index_gan in range(0, LEN_GAN):
                    neg_list_feature.append(neg_data[index_query * LEN_GAN +
                                                     index_gan][2])
                neg_list_feature = np.asarray(neg_list_feature)

                # Reward for the generated ranking comes from the
                # discriminator.
                gan_reward = sess.run(discriminator.reward,
                                      feed_dict={
                                          discriminator.query_data:
                                          current_query_feature_un,
                                          discriminator.gan_data:
                                          gan_list_feature,
                                          discriminator.neg_data:
                                          neg_list_feature
                                      })

                # Sample LEN_GAN candidate indices from the generator's
                # softmax distribution (policy-gradient sample).
                gan_index = np.random.choice(np.arange(
                    len(unsigned_list_pred_index)),
                                             size=LEN_GAN,
                                             p=prob)

                _ = sess.run(generator.gan_updates,
                             feed_dict={
                                 generator.query_data:
                                 current_query_feature_un,
                                 generator.pred_data: pred_feature,
                                 generator.sample_index: gan_index,
                                 generator.reward: gan_reward,
                             })

            # BUGFIX: evaluate the generator on the queries it was just
            # trained on; the original passed random_query_D_feature (the
            # D-phase queries) together with the G-phase labels,
            # mismatching features and labels.
            G_map = MAP_G(sess, generator, random_query_G_feature,
                          random_query_G_label, pred_label, generated_data,
                          QUERY_TRAIN_SIZE, LEN_GAN)
            if G_map > G_map_best:
                G_map_best = G_map
                print("Best_G_map:", "map_G", G_map)
            print("map:", "map_G", G_map)

            Test_map = MAP_test(sess, generator, query_test_feature,
                                query_test_label, pred_label, pred_feature,
                                QUERY_TEST_SIZE, LEN_UNSIGNED, LEN_GAN)
            if Test_map > Test_map_best:
                Test_map_best = Test_map
                generator.save_model(sess, GAN_MODEL_BEST_FILE)
                print("Best_Test_map:", "map_Test", Test_map)
            print("map:", "map_Test", Test_map)

    # Final test: reload the best checkpoint into a fresh generator.
    sess.close()  # BUGFIX: the training session was never closed.
    # BUGFIX: pickle files must be read in binary mode; plain 'rb' (not
    # 'rb+') suffices, and the context manager closes the file handle.
    with open(GAN_MODEL_BEST_FILE, 'rb') as f:
        param_best = pickle.load(f, encoding="iso-8859-1")
    assert param_best is not None
    generator_best = GEN(FEATURE_SIZE,
                         G_WEIGHT_DECAY,
                         G_LEARNING_RATE,
                         LEN_UNSIGNED,
                         param=param_best)
    sess = tf.Session(config=config)
    # BUGFIX: tf.initialize_all_variables() is deprecated; use the same
    # initializer as the training session above.
    sess.run(tf.global_variables_initializer())
    map_best = MAP_test(sess, generator_best, query_test_feature,
                        query_test_label, pred_label, pred_feature,
                        QUERY_TEST_SIZE, LEN_UNSIGNED, LEN_GAN)
    print("GAN_Best MAP ", map_best)

    sess.close()
# Example #2
query_pos_test = ut.get_query_pos(workdir + '/test.txt')

# Reload the best generator parameters saved during training.
# BUGFIX: pickle data must be read in binary mode ('rb'); the original
# opened the file in text mode and never closed the handle.
with open(GAN_MODEL_BEST_FILE, 'rb') as f:
    param_best = cPickle.load(f)
assert param_best is not None
generator_best = GEN(FEATURE_SIZE, HIDDEN_SIZE, WEIGHT_DECAY, G_LEARNING_RATE, temperature=TEMPERATURE, param=param_best)

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
# BUGFIX: tf.initialize_all_variables() is deprecated in TF 1.x.
sess.run(tf.global_variables_initializer())

# Evaluate the restored generator: precision@k, ndcg@k, MAP and MRR.
p_1_best = precision_at_k(sess, generator_best, query_pos_test, query_pos_train, query_url_feature, k=1)
p_3_best = precision_at_k(sess, generator_best, query_pos_test, query_pos_train, query_url_feature, k=3)
p_5_best = precision_at_k(sess, generator_best, query_pos_test, query_pos_train, query_url_feature, k=5)
p_10_best = precision_at_k(sess, generator_best, query_pos_test, query_pos_train, query_url_feature, k=10)

ndcg_1_best = ndcg_at_k(sess, generator_best, query_pos_test, query_pos_train, query_url_feature, k=1)
ndcg_3_best = ndcg_at_k(sess, generator_best, query_pos_test, query_pos_train, query_url_feature, k=3)
ndcg_5_best = ndcg_at_k(sess, generator_best, query_pos_test, query_pos_train, query_url_feature, k=5)
ndcg_10_best = ndcg_at_k(sess, generator_best, query_pos_test, query_pos_train, query_url_feature, k=10)

map_best = MAP(sess, generator_best, query_pos_test, query_pos_train, query_url_feature)

mrr_best = MRR(sess, generator_best, query_pos_test, query_pos_train, query_url_feature)

print("Best ", "p@1 ", p_1_best, "p@3 ", p_3_best, "p@5 ", p_5_best, "p@10 ", p_10_best)
# BUGFIX: the last label read "p@10" for an ndcg@10 value.
print("Best ", "ndcg@1 ", ndcg_1_best, "ndcg@3 ", ndcg_3_best, "ndcg@5 ", ndcg_5_best, "ndcg@10 ", ndcg_10_best)
print("Best MAP ", map_best)
print("Best MRR ", mrr_best)
# Example #3
def main():
    """Pre-train the discriminator on uniformly sampled negatives,
    checkpoint the best model by p@5 (ties broken by ndcg@5), then reload
    it and report the full evaluation metrics.

    Depends on module-level globals: DIS, ut, generate_uniform,
    query_pos_test/train, query_url_feature, the metric helpers
    (precision_at_k, ndcg_at_k, MAP, MRR) and the size/rate constants.
    """
    discriminator = DIS(FEATURE_SIZE,
                        HIDDEN_SIZE,
                        WEIGHT_DECAY,
                        D_LEARNING_RATE,
                        param=None)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    # BUGFIX: tf.initialize_all_variables() is deprecated in TF 1.x.
    sess.run(tf.global_variables_initializer())

    print('start random negative sampling with log ranking discriminator')
    generate_uniform(DIS_TRAIN_FILE)
    train_size = ut.file_len(DIS_TRAIN_FILE)

    p_best_val = 0.0
    ndcg_best_val = 0.0

    for epoch in range(200):
        # Walk the training file in BATCH_SIZE chunks (1-indexed lines);
        # the final chunk may be smaller.
        index = 1
        while True:
            if index > train_size:
                break
            if index + BATCH_SIZE <= train_size + 1:
                input_pos, input_neg = ut.get_batch_data(
                    DIS_TRAIN_FILE, index, BATCH_SIZE)
            else:
                input_pos, input_neg = ut.get_batch_data(
                    DIS_TRAIN_FILE, index, train_size - index + 1)
            index += BATCH_SIZE

            # Stack positives then negatives with matching 1/0 labels.
            pred_data = []
            pred_data.extend(input_pos)
            pred_data.extend(input_neg)
            pred_data = np.asarray(pred_data)

            pred_data_label = [1.0] * len(input_pos)
            pred_data_label.extend([0.0] * len(input_neg))
            pred_data_label = np.asarray(pred_data_label)

            _ = sess.run(discriminator.d_updates,
                         feed_dict={
                             discriminator.pred_data: pred_data,
                             discriminator.pred_data_label: pred_data_label
                         })

        p_5 = precision_at_k(sess,
                             discriminator,
                             query_pos_test,
                             query_pos_train,
                             query_url_feature,
                             k=5)
        ndcg_5 = ndcg_at_k(sess,
                           discriminator,
                           query_pos_test,
                           query_pos_train,
                           query_url_feature,
                           k=5)

        # Checkpoint on a new best p@5; break ties with ndcg@5.
        if p_5 > p_best_val:
            p_best_val = p_5
            # BUGFIX: also record the ndcg that goes with the new best
            # p@5; the original left ndcg_best_val stale here, which
            # corrupted the tie-break below (the adversarial-training
            # snippet elsewhere in this file does this correctly).
            ndcg_best_val = ndcg_5
            discriminator.save_model(sess, MLE_MODEL_BEST_FILE)
            print("Best: ", " p@5 ", p_5, "ndcg@5 ", ndcg_5)
        elif p_5 == p_best_val:
            if ndcg_5 > ndcg_best_val:
                ndcg_best_val = ndcg_5
                discriminator.save_model(sess, MLE_MODEL_BEST_FILE)
                print("Best: ", " p@5 ", p_5, "ndcg@5 ", ndcg_5)

    sess.close()
    # BUGFIX: read the pickle in binary mode and close the file handle.
    with open(MLE_MODEL_BEST_FILE, 'rb') as f:
        param_best = cPickle.load(f)
    assert param_best is not None
    discriminator_best = DIS(FEATURE_SIZE,
                             HIDDEN_SIZE,
                             WEIGHT_DECAY,
                             D_LEARNING_RATE,
                             param=param_best)

    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    p_1_best = precision_at_k(sess,
                              discriminator_best,
                              query_pos_test,
                              query_pos_train,
                              query_url_feature,
                              k=1)
    p_3_best = precision_at_k(sess,
                              discriminator_best,
                              query_pos_test,
                              query_pos_train,
                              query_url_feature,
                              k=3)
    p_5_best = precision_at_k(sess,
                              discriminator_best,
                              query_pos_test,
                              query_pos_train,
                              query_url_feature,
                              k=5)
    p_10_best = precision_at_k(sess,
                               discriminator_best,
                               query_pos_test,
                               query_pos_train,
                               query_url_feature,
                               k=10)

    ndcg_1_best = ndcg_at_k(sess,
                            discriminator_best,
                            query_pos_test,
                            query_pos_train,
                            query_url_feature,
                            k=1)
    ndcg_3_best = ndcg_at_k(sess,
                            discriminator_best,
                            query_pos_test,
                            query_pos_train,
                            query_url_feature,
                            k=3)
    ndcg_5_best = ndcg_at_k(sess,
                            discriminator_best,
                            query_pos_test,
                            query_pos_train,
                            query_url_feature,
                            k=5)
    ndcg_10_best = ndcg_at_k(sess,
                             discriminator_best,
                             query_pos_test,
                             query_pos_train,
                             query_url_feature,
                             k=10)

    map_best = MAP(sess, discriminator_best, query_pos_test, query_pos_train,
                   query_url_feature)
    mrr_best = MRR(sess, discriminator_best, query_pos_test, query_pos_train,
                   query_url_feature)

    print("Best ", "p@1 ", p_1_best, "p@3 ", p_3_best, "p@5 ", p_5_best,
          "p@10 ", p_10_best)
    # BUGFIX: the final label read "p@10" for an ndcg@10 value.
    print("Best ", "ndcg@1 ", ndcg_1_best, "ndcg@3 ", ndcg_3_best, "ndcg@5 ",
          ndcg_5_best, "ndcg@10 ", ndcg_10_best)
    print("Best MAP ", map_best)
    print("Best MRR ", mrr_best)
    # BUGFIX: the evaluation session was never closed.
    sess.close()
# Example #4
def main():
    """IRGAN-style adversarial training for ranking.

    Alternates discriminator training (on negatives sampled from the
    generator) with generator policy-gradient updates that use
    importance sampling over the candidate distribution. The best
    generator by p@5 (ties broken by ndcg@5) is checkpointed to
    GAN_MODEL_BEST_FILE, then reloaded and fully evaluated.

    Depends on module-level globals: GEN/DIS, ut, generate_for_d,
    query_pos_train/test, query_index_url, query_url_feature, the metric
    helpers and the size/rate constants.
    """
    discriminator = DIS(FEATURE_SIZE,
                        HIDDEN_SIZE,
                        WEIGHT_DECAY,
                        D_LEARNING_RATE,
                        param=None)
    generator = GEN(FEATURE_SIZE,
                    HIDDEN_SIZE,
                    WEIGHT_DECAY,
                    G_LEARNING_RATE,
                    temperature=TEMPERATURE,
                    param=None)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    # BUGFIX: tf.initialize_all_variables() is deprecated in TF 1.x.
    sess.run(tf.global_variables_initializer())

    print('start adversarial training')

    p_best_val = 0.0
    ndcg_best_val = 0.0

    for epoch in range(30):
        # G generates negatives for D, then D is trained on them.
        # (An always-true `if epoch >= 0:` guard around this phase was
        # removed; behavior is unchanged.)
        print('Training D ...')
        for d_epoch in range(100):
            # Refresh D's training file from the current G every 30
            # D-epochs; this fires at d_epoch == 0, so train_size is
            # always set before use.
            if d_epoch % 30 == 0:
                generate_for_d(sess, generator, DIS_TRAIN_FILE)
                train_size = ut.file_len(DIS_TRAIN_FILE)

            # Walk the training file in BATCH_SIZE chunks (1-indexed
            # lines); the final chunk may be smaller.
            index = 1
            while True:
                if index > train_size:
                    break
                if index + BATCH_SIZE <= train_size + 1:
                    input_pos, input_neg = ut.get_batch_data(
                        DIS_TRAIN_FILE, index, BATCH_SIZE)
                else:
                    input_pos, input_neg = ut.get_batch_data(
                        DIS_TRAIN_FILE, index, train_size - index + 1)
                index += BATCH_SIZE

                # Stack positives then negatives with 1/0 labels.
                pred_data = []
                pred_data.extend(input_pos)
                pred_data.extend(input_neg)
                pred_data = np.asarray(pred_data)

                pred_data_label = [1.0] * len(input_pos)
                pred_data_label.extend([0.0] * len(input_neg))
                pred_data_label = np.asarray(pred_data_label)

                _ = sess.run(discriminator.d_updates,
                             feed_dict={
                                 discriminator.pred_data: pred_data,
                                 discriminator.pred_data_label:
                                 pred_data_label
                             })
        # Train G
        print('Training G ...')
        for g_epoch in range(30):
            for query in query_pos_train.keys():
                pos_list = query_pos_train[query]
                pos_set = set(pos_list)
                all_list = query_index_url[query]

                all_list_feature = [
                    query_url_feature[query][url] for url in all_list
                ]
                all_list_feature = np.asarray(all_list_feature)
                all_list_score = sess.run(
                    generator.pred_score,
                    {generator.pred_data: all_list_feature})

                # Numerically stable softmax over all candidates.
                exp_rating = np.exp(all_list_score - np.max(all_list_score))
                prob = exp_rating / np.sum(exp_rating)

                # Importance-sampling proposal: mix G's distribution with
                # uniform mass LAMBDA spread over the known positives.
                prob_IS = prob * (1.0 - LAMBDA)

                for i in range(len(all_list)):
                    if all_list[i] in pos_set:
                        prob_IS[i] += (LAMBDA / (1.0 * len(pos_list)))

                # Draw 5 positives' worth of samples from the proposal.
                choose_index = np.random.choice(np.arange(len(all_list)),
                                                [5 * len(pos_list)],
                                                p=prob_IS)
                choose_list = np.array(all_list)[choose_index]
                choose_feature = [
                    query_url_feature[query][url] for url in choose_list
                ]
                # Importance weights: target prob / proposal prob.
                choose_IS = np.array(prob)[choose_index] / np.array(
                    prob_IS)[choose_index]

                choose_index = np.asarray(choose_index)
                choose_feature = np.asarray(choose_feature)
                choose_IS = np.asarray(choose_IS)

                # Reward for the sampled documents comes from D.
                choose_reward = sess.run(
                    discriminator.reward,
                    feed_dict={discriminator.pred_data: choose_feature})

                _ = sess.run(generator.g_updates,
                             feed_dict={
                                 generator.pred_data: all_list_feature,
                                 generator.sample_index: choose_index,
                                 generator.reward: choose_reward,
                                 generator.important_sampling: choose_IS
                             })

            p_5 = precision_at_k(sess,
                                 generator,
                                 query_pos_test,
                                 query_pos_train,
                                 query_url_feature,
                                 k=5)
            ndcg_5 = ndcg_at_k(sess,
                               generator,
                               query_pos_test,
                               query_pos_train,
                               query_url_feature,
                               k=5)

            # Checkpoint on a new best p@5; break ties with ndcg@5.
            if p_5 > p_best_val:
                p_best_val = p_5
                ndcg_best_val = ndcg_5
                generator.save_model(sess, GAN_MODEL_BEST_FILE)
                print("Best:", "gen p@5 ", p_5, "gen ndcg@5 ", ndcg_5)
            elif p_5 == p_best_val:
                if ndcg_5 > ndcg_best_val:
                    ndcg_best_val = ndcg_5
                    generator.save_model(sess, GAN_MODEL_BEST_FILE)
                    print("Best:", "gen p@5 ", p_5, "gen ndcg@5 ", ndcg_5)

    sess.close()
    # BUGFIX: read the pickle in binary mode and close the file handle.
    with open(GAN_MODEL_BEST_FILE, 'rb') as f:
        param_best = cPickle.load(f)
    assert param_best is not None
    generator_best = GEN(FEATURE_SIZE,
                         HIDDEN_SIZE,
                         WEIGHT_DECAY,
                         G_LEARNING_RATE,
                         temperature=TEMPERATURE,
                         param=param_best)
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    p_1_best = precision_at_k(sess,
                              generator_best,
                              query_pos_test,
                              query_pos_train,
                              query_url_feature,
                              k=1)
    p_3_best = precision_at_k(sess,
                              generator_best,
                              query_pos_test,
                              query_pos_train,
                              query_url_feature,
                              k=3)
    p_5_best = precision_at_k(sess,
                              generator_best,
                              query_pos_test,
                              query_pos_train,
                              query_url_feature,
                              k=5)
    p_10_best = precision_at_k(sess,
                               generator_best,
                               query_pos_test,
                               query_pos_train,
                               query_url_feature,
                               k=10)

    ndcg_1_best = ndcg_at_k(sess,
                            generator_best,
                            query_pos_test,
                            query_pos_train,
                            query_url_feature,
                            k=1)
    ndcg_3_best = ndcg_at_k(sess,
                            generator_best,
                            query_pos_test,
                            query_pos_train,
                            query_url_feature,
                            k=3)
    ndcg_5_best = ndcg_at_k(sess,
                            generator_best,
                            query_pos_test,
                            query_pos_train,
                            query_url_feature,
                            k=5)
    ndcg_10_best = ndcg_at_k(sess,
                             generator_best,
                             query_pos_test,
                             query_pos_train,
                             query_url_feature,
                             k=10)

    map_best = MAP(sess, generator_best, query_pos_test, query_pos_train,
                   query_url_feature)
    mrr_best = MRR(sess, generator_best, query_pos_test, query_pos_train,
                   query_url_feature)

    print("Best ", "p@1 ", p_1_best, "p@3 ", p_3_best, "p@5 ", p_5_best,
          "p@10 ", p_10_best)
    # BUGFIX: the final label read "p@10" for an ndcg@10 value.
    print("Best ", "ndcg@1 ", ndcg_1_best, "ndcg@3 ", ndcg_3_best, "ndcg@5 ",
          ndcg_5_best, "ndcg@10 ", ndcg_10_best)
    print("Best MAP ", map_best)
    print("Best MRR ", mrr_best)
    # BUGFIX: the evaluation session was never closed.
    sess.close()
# Example #5
                        k=1)
# NOTE(review): this chunk continues an evaluation script whose beginning
# (session setup, discriminator_best, p@k calls, ndcg_1_best) lies before
# this view. It computes the remaining ndcg@k, MAP and MRR metrics for
# the restored best discriminator and prints the report.
ndcg_3_best = ndcg_at_k(sess,
                        discriminator_best,
                        query_pos_test,
                        query_pos_train,
                        query_url_feature,
                        k=3)
ndcg_5_best = ndcg_at_k(sess,
                        discriminator_best,
                        query_pos_test,
                        query_pos_train,
                        query_url_feature,
                        k=5)
ndcg_10_best = ndcg_at_k(sess,
                         discriminator_best,
                         query_pos_test,
                         query_pos_train,
                         query_url_feature,
                         k=10)

map_best = MAP(sess, discriminator_best, query_pos_test, query_pos_train,
               query_url_feature)
mrr_best = MRR(sess, discriminator_best, query_pos_test, query_pos_train,
               query_url_feature)

print("Best ", "p@1 ", p_1_best, "p@3 ", p_3_best, "p@5 ", p_5_best, "p@10 ",
      p_10_best)
# BUGFIX: the final label read "p@10" for an ndcg@10 value.
print("Best ", "ndcg@1 ", ndcg_1_best, "ndcg@3 ", ndcg_3_best, "ndcg@5 ",
      ndcg_5_best, "ndcg@10 ", ndcg_10_best)
print("Best MAP ", map_best)
print("Best MRR ", mrr_best)