示例#1
0
def passage_dist_eval(args, model, tokenizer):
    """Evaluate ``model`` with ``combined_dist_eval`` on the dev-small files.

    All input files are resolved relative to ``args.data_dir``:
    ``collection.tsv``, ``queries.dev.small.tsv`` and
    ``qrels.dev.small.tsv`` (the latter is read through ``load_reference``).

    Args:
        args: Run configuration. ``data_dir`` is read here; the object is
            also forwarded to ``dual_process_fn`` and ``combined_dist_eval``.
        model: Model under evaluation, forwarded to ``combined_dist_eval``.
        tokenizer: Tokenizer forwarded to ``dual_process_fn``.

    Returns:
        A 2-tuple whose first element is the constant ``0.0`` and whose
        second element is the value returned by ``combined_dist_eval``
        (named ``full_ranking_mrr`` by this function).
    """
    data_dir = args.data_dir

    def encode_line(line, i):
        # Bind tokenizer/args so the evaluator only has to supply (line, i).
        return dual_process_fn(line, i, tokenizer, args)

    reference = load_reference(os.path.join(data_dir, "qrels.dev.small.tsv"))

    # The same encoder is passed twice (presumably once for queries and once
    # for passages -- confirm against combined_dist_eval's signature).
    full_ranking_mrr = combined_dist_eval(
        args,
        model,
        os.path.join(data_dir, "queries.dev.small.tsv"),
        os.path.join(data_dir, "collection.tsv"),
        encode_line,
        encode_line,
        reference,
    )

    # This variant computes no reranking score; 0.0 fills that slot.
    return 0.0, full_ranking_mrr
示例#2
0
def passage_dist_eval(args, model, tokenizer):
    """Evaluate ``model`` with ``combined_dist_eval`` on the doc-dev files.

    All input files are resolved relative to ``args.data_dir``:
    ``msmarco-docs.tsv``, ``msmarco-docdev-queries.tsv``,
    ``msmarco-docdev-top100`` (parsed by ``parse_top_dev`` with
    ``qid_col=0, pid_col=2``) and ``msmarco-docdev-qrels.tsv`` (read by
    ``load_reference`` with ``data_type=0``).

    Args:
        args: Run configuration. ``data_dir`` is read here; the object is
            also forwarded to ``dual_process_fn_doc`` and
            ``combined_dist_eval``.
        model: Model under evaluation, forwarded to ``combined_dist_eval``.
        tokenizer: Tokenizer forwarded to ``dual_process_fn_doc``.

    Returns:
        The ``(reranking_mrr, full_ranking_mrr)`` pair produced by
        ``combined_dist_eval``.
    """
    data_dir = args.data_dir

    def encode_line(line, i):
        # Bind tokenizer/args so the evaluator only has to supply (line, i).
        return dual_process_fn_doc(line, i, tokenizer, args)

    # Candidate list first, then the reference judgments.
    candidates = parse_top_dev(
        os.path.join(data_dir, "msmarco-docdev-top100"),
        qid_col=0,
        pid_col=2,
    )
    reference = load_reference(
        os.path.join(data_dir, "msmarco-docdev-qrels.tsv"),
        data_type=0,
    )

    # The same encoder is passed twice (presumably once for queries and once
    # for documents -- confirm against combined_dist_eval's signature).
    scores = combined_dist_eval(
        args,
        model,
        os.path.join(data_dir, "msmarco-docdev-queries.tsv"),
        os.path.join(data_dir, "msmarco-docs.tsv"),
        encode_line,
        encode_line,
        candidates,
        reference,
    )
    reranking_mrr, full_ranking_mrr = scores
    return reranking_mrr, full_ranking_mrr
示例#3
0
def passage_dist_eval_last(args, model, tokenizer):
    """Evaluate ``model`` with ``combined_dist_eval_last`` on the dev-small files.

    All input files are resolved relative to ``args.data_dir``:
    ``collection.tsv``, ``queries.dev.small.tsv``, ``top1000.dev.tsv``
    (parsed by ``parse_top_dev`` with ``qid_col=0, pid_col=1``) and
    ``qrels.dev.small.tsv`` (read by ``load_reference``).

    Args:
        args: Run configuration. ``data_dir`` is read here; the object is
            also forwarded to ``dual_process_fn`` and
            ``combined_dist_eval_last``.
        model: Model under evaluation, forwarded to
            ``combined_dist_eval_last``.
        tokenizer: Tokenizer forwarded to ``dual_process_fn``.

    Returns:
        The ``(reranking_mrr, full_ranking_mrr, recall_1000)`` triple
        produced by ``combined_dist_eval_last``.
    """
    data_dir = args.data_dir

    def encode_line(line, i):
        # Bind tokenizer/args so the evaluator only has to supply (line, i).
        return dual_process_fn(line, i, tokenizer, args)

    candidates = parse_top_dev(
        os.path.join(data_dir, "top1000.dev.tsv"),
        qid_col=0,
        pid_col=1,
    )
    reference = load_reference(os.path.join(data_dir, "qrels.dev.small.tsv"))

    # Re-shape the reference into {key: {entry: 1}}: every entry listed under
    # a key in the reference is given the label 1. Labels are derived from the
    # reference itself rather than read from a separate qrel file.
    dev_query_positive_id = {
        key: {entry: 1 for entry in reference[key]} for key in reference
    }
    print('read ok...')

    # The same encoder is passed twice (presumably once for queries and once
    # for passages -- confirm against combined_dist_eval_last's signature).
    scores = combined_dist_eval_last(
        args,
        model,
        os.path.join(data_dir, "queries.dev.small.tsv"),
        os.path.join(data_dir, "collection.tsv"),
        encode_line,
        encode_line,
        candidates,
        reference,
        dev_query_positive_id,
    )
    reranking_mrr, full_ranking_mrr, recall_1000 = scores
    return reranking_mrr, full_ranking_mrr, recall_1000