Exemplo n.º 1
0
def test(model, test_dataset, dictionary, sess):
    """Evaluate ``model`` on ``test_dataset`` and print ranking and BLEU scores.

    For each batch, the click probabilities and the generated query ids are
    fetched through ``sess.run``. MAP, MRR and NDCG@{1, 3, 5, 10} are computed
    per batch and then averaged over the number of batches, so every batch
    carries the same weight regardless of how many examples it holds. The
    generated and target texts of all batches are collected and handed to
    ``multi_bleu.print_multi_bleu``.

    Args:
        model: Object exposing the placeholders used as ``feed_dict`` keys
            below, the fetches ``click_prob``, ``predicting_ids`` and
            ``predicting_len``, and a ``get_memory_input`` method.
        test_dataset: Sized object whose ``dataset`` attribute is indexed
            with the indices of each batch.
        dictionary: Forwarded to the ``helper`` text-generation functions.
        sess: Object whose ``run(fetches, feed_dict=...)`` evaluates the
            fetches (presumably a TensorFlow session -- confirm with caller).

    Returns:
        None. All results are printed to stdout.
    """
    batches_idx = helper.get_batches_idx(len(test_dataset), args.batch_size)
    num_batches = len(batches_idx)
    print('number of test batches = ', num_batches)

    if num_batches == 0:
        # Nothing to evaluate: the per-batch averages below would otherwise
        # raise ZeroDivisionError.
        print('no test batches to evaluate')
        return

    predicts, targets = [], []
    # Running sums of the per-batch metric values. ``map_total`` deliberately
    # avoids the name ``map`` so the builtin is not shadowed.
    map_total, mrr_total = 0, 0
    ndcg_1_total, ndcg_3_total, ndcg_5_total, ndcg_10_total = 0, 0, 0, 0

    for batch_idx in batches_idx:
        batch_data = [test_dataset.dataset[i] for i in batch_idx]

        # Convert the batch into the input format expected by the model.
        # ``next_q`` and ``next_q_len`` are part of the returned tuple but
        # are not fed to the model here.
        (hist_query_input, hist_doc_input, session_num, hist_query_num,
         hist_query_len, hist_click_num, hist_doc_len, cur_query_input,
         cur_doc_input, cur_query_num, cur_query_len, cur_click_num,
         cur_doc_len, query, q_len, doc, d_len, y, next_q, next_q_len,
         maximum_iterations) = helper.batch_to_tensor(batch_data,
                                                      args.max_query_len,
                                                      args.max_doc_len)

        indices, slots_num = model.get_memory_input(session_num)

        feed_dict = {
            model.hist_query_input: hist_query_input,
            model.hist_doc_input: hist_doc_input,
            model.session_num: session_num,
            model.hist_query_num: hist_query_num,
            model.hist_query_len: hist_query_len,
            model.hist_click_num: hist_click_num,
            model.hist_doc_len: hist_doc_len,
            model.cur_query_input: cur_query_input,
            model.cur_doc_input: cur_doc_input,
            model.cur_query_num: cur_query_num,
            model.cur_query_len: cur_query_len,
            model.cur_click_num: cur_click_num,
            model.cur_doc_len: cur_doc_len,
            model.q: query,
            model.q_len: q_len,
            model.d: doc,
            model.d_len: d_len,
            model.indices: indices,
            model.slots_num: slots_num,
            model.maximum_iterations: maximum_iterations,
        }

        click_prob_, predicting_ids_, predicting_len_ = sess.run(
            [model.click_prob, model.predicting_ids, model.predicting_len],
            feed_dict=feed_dict)

        # Ranking metrics of this batch, scored against the labels ``y``.
        map_total += mean_average_precision(click_prob_, y)
        mrr_total += MRR(click_prob_, y)
        ndcg_1_total += NDCG(click_prob_, y, 1)
        ndcg_3_total += NDCG(click_prob_, y, 3)
        ndcg_5_total += NDCG(click_prob_, y, 5)
        ndcg_10_total += NDCG(click_prob_, y, 10)

        # Collect generated and reference texts for the BLEU computation;
        # the second value returned by generate_target_text is not needed.
        batch_predicting_text = helper.generate_predicting_text(
            predicting_ids_, predicting_len_, dictionary)
        batch_target_text, _ = helper.generate_target_text(
            batch_data, dictionary, args.max_query_len)
        predicts += batch_predicting_text
        targets += batch_target_text

    # Report each metric as the mean of its per-batch values.
    print('MAP - ', map_total / num_batches)
    print('MRR - ', mrr_total / num_batches)
    print('NDCG@1 - ', ndcg_1_total / num_batches)
    print('NDCG@3 - ', ndcg_3_total / num_batches)
    print('NDCG@5 - ', ndcg_5_total / num_batches)
    print('NDCG@10 - ', ndcg_10_total / num_batches)

    print("targets size = ", len(targets))
    print("predicts size = ", len(predicts))

    multi_bleu.print_multi_bleu(predicts, targets)
Exemplo n.º 2
0
                      args.max_example,
                      whole_session=True)
    print('test set size = ', len(test_corpus.data))

    targets, candidates = [], []
    if args.attn_type:
        fw = open(args.save_path + 'seq2seq_attn_predictions.txt', 'w')
    else:
        fw = open(args.save_path + 'seq2seq_predictions.txt', 'w')
    for prev_q, current_q in test_corpus.data:
        q1_var, q1_len, q2_var, q2_len = helper.batch_to_tensor(
            [(prev_q, current_q)],
            dictionary,
            reverse=args.reverse,
            iseval=True)
        if args.cuda:
            q1_var = q1_var.cuda()  # batch_size x max_len
            q2_var = q2_var.cuda()  # batch_size x max_len
            q2_len = q2_len.cuda()  # batch_size

        target = generate_next_query(model, q1_var, q1_len, dictionary)
        candidate = " ".join(current_q.query_terms[1:-1])
        targets.append(target)
        candidates.append(candidate)
        fw.write(candidate + '\t' + target + '\n')
    fw.close()

    print("target size = ", len(targets))
    print("candidate size = ", len(candidates))
    multi_bleu.print_multi_bleu(targets, candidates)