def passage_dist_eval(args, model, tokenizer):
    """Run the distributed dev evaluation over the passage collection.

    The inputs are ``collection.tsv``, ``queries.dev.small.tsv`` and
    ``qrels.dev.small.tsv``, all looked up inside ``args.data_dir``.  Both
    queries and passages are preprocessed with ``dual_process_fn``.

    NOTE(review): another ``passage_dist_eval`` (reading the
    ``msmarco-doc*`` files) is defined further down in this source; if both
    live in the same module the later definition replaces this one --
    confirm which is intended.

    Args:
        args: Run configuration.  ``args.data_dir`` is read here; the whole
            object is also forwarded to ``dual_process_fn`` and
            ``combined_dist_eval``.
        model: Forwarded unchanged to ``combined_dist_eval``.
        tokenizer: Forwarded to ``dual_process_fn`` for every line.

    Returns:
        A 2-tuple ``(0.0, full_ranking_mrr)`` where ``full_ranking_mrr`` is
        whatever ``combined_dist_eval`` returns.  The first element is the
        constant ``0.0`` -- presumably a placeholder for the re-ranking
        score that is not computed in this variant (verify against callers).
    """
    data_dir = args.data_dir
    collection_file = os.path.join(data_dir, "collection.tsv")
    dev_queries_file = os.path.join(data_dir, "queries.dev.small.tsv")
    qrels_file = os.path.join(data_dir, "qrels.dev.small.tsv")

    def process_line(line, i):
        # Same preprocessing is applied to query lines and passage lines.
        return dual_process_fn(line, i, tokenizer, args)

    # The top-1000 candidate file (top1000.dev.tsv) is intentionally not
    # loaded here, so no candidate list is handed to combined_dist_eval.
    relevance = load_reference(qrels_file)

    full_ranking_mrr = combined_dist_eval(
        args,
        model,
        dev_queries_file,
        collection_file,
        process_line,
        process_line,
        relevance,
    )
    return 0.0, full_ranking_mrr
def passage_dist_eval(args, model, tokenizer):
    """Run the distributed dev evaluation over the document collection.

    The inputs are ``msmarco-docs.tsv``, ``msmarco-docdev-queries.tsv``,
    ``msmarco-docdev-top100`` and ``msmarco-docdev-qrels.tsv``, all looked
    up inside ``args.data_dir``.  Both queries and documents are
    preprocessed with ``dual_process_fn_doc``.

    NOTE(review): an earlier ``passage_dist_eval`` (reading
    ``collection.tsv``) appears above in this source; if both live in the
    same module this definition replaces it -- confirm which is intended.

    Args:
        args: Run configuration.  ``args.data_dir`` is read here; the whole
            object is also forwarded to ``dual_process_fn_doc`` and
            ``combined_dist_eval``.
        model: Forwarded unchanged to ``combined_dist_eval``.
        tokenizer: Forwarded to ``dual_process_fn_doc`` for every line.

    Returns:
        The pair ``(reranking_mrr, full_ranking_mrr)`` exactly as unpacked
        from ``combined_dist_eval``.
    """
    data_dir = args.data_dir
    docs_file = os.path.join(data_dir, "msmarco-docs.tsv")
    dev_queries_file = os.path.join(data_dir, "msmarco-docdev-queries.tsv")

    def process_line(line, i):
        # Same preprocessing is applied to query lines and document lines.
        return dual_process_fn_doc(line, i, tokenizer, args)

    # Candidate list: query id is taken from column 0, document id from
    # column 2 of the top-100 file.
    candidates_file = os.path.join(data_dir, "msmarco-docdev-top100")
    candidates = parse_top_dev(candidates_file, qid_col=0, pid_col=2)

    qrels_file = os.path.join(data_dir, "msmarco-docdev-qrels.tsv")
    relevance = load_reference(qrels_file, data_type=0)

    scores = combined_dist_eval(
        args,
        model,
        dev_queries_file,
        docs_file,
        process_line,
        process_line,
        candidates,
        relevance,
    )
    # Unpack explicitly so a result of the wrong arity still fails here.
    reranking_mrr, full_ranking_mrr = scores
    return reranking_mrr, full_ranking_mrr
def passage_dist_eval_last(args, model, tokenizer):
    """Run the distributed dev evaluation via ``combined_dist_eval_last``.

    The inputs are ``collection.tsv``, ``queries.dev.small.tsv``,
    ``top1000.dev.tsv`` and ``qrels.dev.small.tsv``, all looked up inside
    ``args.data_dir``.  Both queries and passages are preprocessed with
    ``dual_process_fn``.  In addition to the reference loaded by
    ``load_reference``, a nested ``{query: {positive: 1}}`` mapping is
    derived from it and passed along.

    Args:
        args: Run configuration.  ``args.data_dir`` is read here; the whole
            object is also forwarded to ``dual_process_fn`` and
            ``combined_dist_eval_last``.
        model: Forwarded unchanged to ``combined_dist_eval_last``.
        tokenizer: Forwarded to ``dual_process_fn`` for every line.

    Returns:
        The triple ``(reranking_mrr, full_ranking_mrr, recall_1000)``
        exactly as unpacked from ``combined_dist_eval_last``.
    """
    data_dir = args.data_dir
    collection_file = os.path.join(data_dir, "collection.tsv")
    dev_queries_file = os.path.join(data_dir, "queries.dev.small.tsv")

    def process_line(line, i):
        # Same preprocessing is applied to query lines and passage lines.
        return dual_process_fn(line, i, tokenizer, args)

    # Candidate list: query id is taken from column 0, passage id from
    # column 1 of the top-1000 file.
    candidates_file = os.path.join(data_dir, "top1000.dev.tsv")
    top1k_candidates = parse_top_dev(candidates_file, qid_col=0, pid_col=1)

    qrels_file = os.path.join(data_dir, "qrels.dev.small.tsv")
    ref_dict = load_reference(qrels_file)

    # Re-shape the reference into {key: {entry: 1}}.  It is derived from
    # ref_dict rather than read from a separate qrel file; every entry
    # listed under a key is marked with the value 1.
    dev_query_positive_id = {}
    for query_key in ref_dict:
        dev_query_positive_id.setdefault(query_key, {}).update(
            (positive_key, 1) for positive_key in ref_dict[query_key]
        )
    print('read ok...')

    scores = combined_dist_eval_last(
        args,
        model,
        dev_queries_file,
        collection_file,
        process_line,
        process_line,
        top1k_candidates,
        ref_dict,
        dev_query_positive_id,
    )
    # Unpack explicitly so a result of the wrong arity still fails here.
    reranking_mrr, full_ranking_mrr, recall_1000 = scores
    return reranking_mrr, full_ranking_mrr, recall_1000