def main():
    """Run a paired t-test of a per-query metric between two ranked lists.

    Usage: <judgment_path> <metric> <baseline_ranked_list> <treatment_ranked_list>

    Each ranked list is scored per query with ``get_metric_fn(metric)``
    against the qrels.  Only queries scored in both runs are paired; a
    notice is printed when some queries could not be matched.  Prints the
    mean of the baseline and treatment scores, then the statistic and
    p-value returned by ``stats.ttest_rel``.

    Raises:
        ValueError: if no query is scored in both runs (previously this
            surfaced as an opaque unpacking error from ``zip(*[])``).
    """
    judgment_path = sys.argv[1]
    metric = sys.argv[2]
    ranked_list_path1 = sys.argv[3]
    ranked_list_path2 = sys.argv[4]
    qrels = load_qrels_flat(judgment_path)
    ranked_list_1: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(
        ranked_list_path1)
    ranked_list_2: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(
        ranked_list_path2)
    metric_fn = get_metric_fn(metric)
    score_d1 = get_score_per_query(qrels, metric_fn, ranked_list_1)
    score_d2 = get_score_per_query(qrels, metric_fn, ranked_list_2)

    # Pair up queries scored in both runs, in score_d1's iteration order.
    pairs = [(score_d1[key], score_d2[key]) for key in score_d1 if key in score_d2]
    if len(pairs) < len(score_d1) or len(pairs) < len(score_d2):
        print("{} matched from {} and {} scores".format(
            len(pairs), len(score_d1), len(score_d2)))
    if not pairs:
        raise ValueError("No query is scored in both ranked lists")

    l1, l2 = zip(*pairs)
    d, p_value = stats.ttest_rel(l1, l2)
    print("baseline:", average(l1))
    print("treatment:", average(l2))
    print(d, p_value)
def main():
    """Print per-query metric differences between two ranked lists.

    Usage: <judgment_path> <metric> <ranked_list_1> <ranked_list_2>

    Only queries scored in both runs are reported.  Each output line is
    ``query_id delta score1 score2`` (two decimals, delta = score2 - score1),
    ordered from the largest gain of run 2 over run 1 to the largest loss.
    """
    judgment_path = sys.argv[1]
    metric = sys.argv[2]
    first_path = sys.argv[3]
    second_path = sys.argv[4]

    qrels = load_qrels_flat(judgment_path)
    grouped_first: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(first_path)
    grouped_second: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(second_path)
    metric_fn = get_metric_fn(metric)
    per_query_first = get_score_per_query(qrels, metric_fn, grouped_first)
    per_query_second = get_score_per_query(qrels, metric_fn, grouped_second)

    rows = [
        (qid, per_query_first[qid], per_query_second[qid])
        for qid in per_query_first
        if qid in per_query_second
    ]
    rows.sort(key=lambda row: row[2] - row[1], reverse=True)

    for qid, first_score, second_score in rows:
        delta = second_score - first_score
        print("{0} {1:.2f} {2:.2f} {3:.2f}".format(qid, delta, first_score, second_score))
def main():
    """Fuse two ranked lists query by query, as described by a JSON run config.

    Usage: <run_config.json>

    Config keys: ``first_list`` and ``second_list`` (ranked list paths),
    ``run_name``, ``strategy`` (``"reciprocal"`` or ``"weighted_sum"``),
    ``save_path``, and ``k1``/``k2`` (passed through to the fusion function).

    Queries present in both lists are fused and re-emitted under
    ``run_name``.  Queries present only in the first list are copied
    unchanged.  Queries present only in the second list trigger a warning
    and are dropped.

    Raises:
        ValueError: if ``strategy`` is not a known fusion strategy.  This
            replaces an ``assert False`` that ``python -O`` would strip.
    """
    with open(sys.argv[1], "r") as f:
        run_config = json.load(f)
    l1: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(run_config['first_list'])
    l2: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(run_config['second_list'])
    run_name = run_config['run_name']
    strategy = run_config['strategy']
    save_path = run_config['save_path']
    k1 = run_config['k1']
    k2 = run_config['k2']

    fusion_fns = {
        "reciprocal": reciprocal_fusion,
        "weighted_sum": weighted_sum_fusion,
    }
    if strategy not in fusion_fns:
        raise ValueError("Unknown fusion strategy: {}".format(strategy))
    fusion_fn = fusion_fns[strategy]

    for key in l2:
        if key not in l1:
            print("WARNING qid {} is not in the first list".format(key))

    # Build a fresh dict rather than aliasing and mutating l1 in place.
    new_entries: Dict[str, List[TrecRankedListEntry]] = {}
    for qid, entries1 in l1.items():
        if qid not in l2:
            new_entries[qid] = entries1
            continue
        fused_scores = fusion_fn(entries1, l2[qid], k1, k2)
        new_entries[qid] = scores_to_ranked_list_entries(fused_scores, run_name, qid)

    flat_entries: Iterable[TrecRankedListEntry] = flatten(new_entries.values())
    write_trec_ranked_list_entry(flat_entries, save_path)
def main():
    """Print the (train, dev) query pairs whose top-10 documents overlap most.

    Loads the ClueWeb result lists ``train.txt`` and ``dev.txt`` from the
    perspective_experiments/clueweb_qres output directory.  For every
    (train query, dev query) pair, the overlap is the number of shared
    top-k doc ids divided by k.  Pairs with overlap > 0.1 are kept, and the
    10 largest are printed as ``overlap train_qid dev_qid``.
    """
    saved_dir = at_output_dir("perspective_experiments", "clueweb_qres")
    rlg1 = load_ranked_list_grouped(os.path.join(saved_dir, "train.txt"))
    rlg2 = load_ranked_list_grouped(os.path.join(saved_dir, "dev.txt"))
    k = 10

    def top_k_doc_set(entries):
        return set(lmap(TrecRankedListEntry.get_doc_id, entries[:k]))

    # The top-k sets do not depend on the pairing, so compute each one once
    # instead of rebuilding both lists inside the nested loop.
    top_k1 = {qid: top_k_doc_set(entries) for qid, entries in rlg1.items()}
    top_k2 = {qid: top_k_doc_set(entries) for qid, entries in rlg2.items()}

    most_common = []
    for query_id1, docs1 in top_k1.items():
        for query_id2, docs2 in top_k2.items():
            percent_common = len(docs1 & docs2) / k
            if percent_common > 0.1:
                most_common.append((percent_common, query_id1, query_id2))
    most_common.sort(key=get_first, reverse=True)
    for rate_common, qid1, qid2 in most_common[:10]:
        print(rate_common, qid1, qid2)
def main():
    """Print, for each query, how the proposed tf-idf run's documents fare elsewhere.

    Usage: <proposed_tfidf_list> <proposed_bm25_list> <bert_tfidf_list> <qrels>

    For every query of the proposed tf-idf run, prints a table headed by the
    query id.  Each row describes one document of that run: doc_id,
    relevance label (0 if unjudged), whether the doc is also in the proposed
    BM25 run, and its rank/score in the proposed tf-idf run and in the BERT
    tf-idf run.  Tables with no document rows are not printed.

    One document missing from the BERT run (over the whole input) is
    tolerated and skipped.  A second one raises AssertionError; this is the
    original ``assert cnt == 0`` guard, now raised explicitly so it survives
    ``python -O``.
    """
    rlg_proposed_tfidf = load_ranked_list_grouped(sys.argv[1])
    rlg_proposed_bm25 = load_ranked_list_grouped(sys.argv[2])
    rlg_bert_tfidf = load_ranked_list_grouped(sys.argv[3])
    qrel: QRelsDict = load_qrels_structured(sys.argv[4])
    # TODO
    # Q1 ) Is the set of document different?
    # Q2 ) Say 2 (BM25) is better than 1 (tf-idf),
    #   1. X=0, in 2 not in 1
    #   2. X=1, in 2 not in 1
    #   3. X=0, in 1 not in 2 -> FP prediction that BERT(baseline) misses
    #   4. X=1, in 1 not in 2
    header = ['doc_id', 'label', 'in_bm25', '1_rank', '1_score', '3_rank', '3_score']
    num_missing_in_bert = 0
    for q, entries1 in rlg_proposed_tfidf.items():
        entries2 = rlg_proposed_bm25[q]
        entries3 = rlg_bert_tfidf[q]
        e3_d = {e.doc_id: e for e in entries3}
        docs2 = set(map(TrecRankedListEntry.get_doc_id, entries2))
        d = qrel[q]
        rows = [[q], list(header)]
        for e in entries1:
            label = d[e.doc_id] if e.doc_id in d else 0
            try:
                e3 = e3_d[e.doc_id]
            except KeyError:
                if num_missing_in_bert > 0:
                    raise AssertionError(
                        "more than one document missing from the BERT ranked list")
                num_missing_in_bert += 1
                continue
            rows.append([e.doc_id, label, e.doc_id in docs2, e.rank, e.score,
                         e3.rank, e3.score])
        if len(rows) > 2:
            print_table(rows)
def filter_with_ranked_list_path(qk_name: str, ranked_list_path: str, threshold, top_k):
    """Load pickled QK units and filter them against a ranked list file.

    File-loading wrapper around ``filter_with_ranked_list``:
      * the ranked list at ``ranked_list_path`` is grouped by query;
      * the QK units are read with ``load_from_pickle(qk_name)``;
      * both are passed on with ``threshold`` and ``top_k``.

    Returns whatever ``filter_with_ranked_list`` returns.
    """
    grouped_ranked_list: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(
        ranked_list_path)
    units = load_from_pickle(qk_name)
    return filter_with_ranked_list(units, grouped_ranked_list, threshold, top_k)
def main():
    """Reverse a ranked list by replacing every score s with 1 - s.

    Usage: <input_ranked_list> <save_path>

    Within each query, entries are re-sorted by the new score (descending;
    ties keep their input order), re-ranked from 0, and tagged with the run
    name "Reverse".  The result is written in TREC format.
    """
    input_path = sys.argv[1]
    save_path = sys.argv[2]
    grouped: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(input_path)
    run_name = "Reverse"

    reversed_lists: Dict[str, List[TrecRankedListEntry]] = {}
    for qid, ranked_list in grouped.items():
        flipped = sorted(
            ((entry.query_id, entry.doc_id, 1 - entry.score) for entry in ranked_list),
            key=lambda triple: triple[2],
            reverse=True,
        )
        reversed_lists[qid] = [
            TrecRankedListEntry(query_id, doc_id, rank, score, run_name)
            for rank, (query_id, doc_id, score) in enumerate(flipped)
        ]

    flat_entries: Iterable[TrecRankedListEntry] = flatten(reversed_lists.values())
    write_trec_ranked_list_entry(flat_entries, save_path)
def get_old_ranked_list(split):
    """Load the grouped ranked list of ``split`` from the perspective experiments output.

    The file read is
    ``<output_path>/perspective_experiments/<candidate_set_name>/<split>.txt``.
    Both ``output_path`` and ``candidate_set_name`` are module-level names.
    """
    file_name = "{}.txt".format(split)
    path = os.path.join(output_path, "perspective_experiments", candidate_set_name, file_name)
    grouped: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(path)
    return grouped
def do_for_robust():
    """Run ``analyze_overlap`` between the Robust04 train and dev splits.

    The original author's note: count candidates that appear as positive in
    the training split but negative in dev/test.

    Data comes from ``data_path/robust``: the ``qrels.rob04.txt`` judgments
    and the ``rob04.desc.galago.2k.out`` run.  Query ids <= 650 form the
    "train" split and ids >= 651 the "dev" split.  Each query's ranked list
    is truncated to its top 100 entries.
    """
    judgment_path = os.path.join(data_path, "robust", "qrels.rob04.txt")
    qrels = load_qrels_flat(judgment_path)
    # NOTE(review): query 672 is given an explicit empty judgment list,
    # presumably because it has no entry in the qrels file and downstream
    # code expects every qid to be present -- confirm.
    qrels['672'] = []
    ranked_list_path = os.path.join(data_path, "robust", "rob04.desc.galago.2k.out")
    rlg_all: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(
        ranked_list_path)
    top_n = 100

    def is_in_split(split, qid):
        if split == "train":
            return int(qid) <= 650
        if split == "dev":
            return int(qid) >= 651
        # An explicit raise instead of `assert False`: under `python -O` the
        # assert is stripped and the function would silently return None.
        raise ValueError("Unknown split: {}".format(split))

    def get_ranked_list(split):
        return {qid: rl[:top_n] for qid, rl in rlg_all.items() if is_in_split(split, qid)}

    train_splits = ["train"]
    eval_splits = ["dev"]
    analyze_overlap(get_ranked_list, qrels, train_splits, eval_splits)
def main():
    """Print the average per-query value of a ranking metric.

    Usage: <judgment_path> <ranked_list_path> <metric>

    Every query of the ranked list that appears in the qrels is scored with
    ``get_metric_fn(metric)`` against its documents judged with a positive
    score.  Queries missing from the qrels are counted, reported, and
    excluded from the average.  Prints ``<metric><TAB><average>``.
    """
    judgment_path = sys.argv[1]
    ranked_list_path = sys.argv[2]
    metric = sys.argv[3]
    qrels = load_qrels_flat_per_query(judgment_path)
    ranked_list: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(
        ranked_list_path)
    metric_fn = get_metric_fn(metric)

    score_per_query_list = []
    not_found = 0
    for query_id, q_ranked_list in ranked_list.items():
        # Guard only the qrels lookup.  Previously the try also wrapped
        # metric_fn, so a KeyError bug inside the metric was silently
        # counted as "query not found".
        try:
            gold_list = qrels[query_id]
        except KeyError:
            not_found += 1
            continue
        true_gold = [doc_id for doc_id, score in gold_list if score > 0]
        score_per_query_list.append(metric_fn(q_ranked_list, true_gold))

    if not_found:
        print("{} of {} queires not found".format(not_found, len(ranked_list)))
    score = average(score_per_query_list)
    print("{}\t{}".format(metric, score))
def qck_gen(job_name, qk_candidate_name, candidate_ranked_list_path, kdp_ranked_list_path,
            split):
    """Launch QCK instance-generation jobs for the claims of ``split``.

    Loads the claim ids of ``split``, the pickled QK candidates
    (``qk_candidate_name``), and the grouped KDP ranked list.  Builds a
    ``QCKInstGenWScore`` generator from the candidate ranked list.  QK units
    whose query id is one of the split's claim ids are handed to
    ``QCKWorker`` jobs, run by ``JobRunnerS`` under the job name
    ``<job_name>_<split>`` with ``d_n_claims_per_split2[split]`` jobs.
    """
    claim_ids = load_claim_ids_for_split(split)
    cids: List[str] = lmap(str, claim_ids)
    qk_candidate: List[QKUnit] = load_from_pickle(qk_candidate_name)
    kdp_ranked_list: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(
        kdp_ranked_list_path)

    print("cids", len(cids))
    print("len(qk_candidate)", len(qk_candidate))
    print("Generate instances : ", split)
    generator = QCKInstGenWScore(
        get_qck_candidate_from_ranked_list_path(candidate_ranked_list_path),
        is_correct_factory(),
        kdp_ranked_list)

    # A set makes the per-unit membership test O(1) instead of scanning cids.
    cid_set = set(cids)
    qk_candidate_train: List[QKUnit] = [qk for qk in qk_candidate if qk[0].query_id in cid_set]

    def worker_factory(out_dir):
        return QCKWorker(qk_candidate_train, generator, out_dir)

    num_jobs = d_n_claims_per_split2[split]
    runner = JobRunnerS(job_man_dir, num_jobs, job_name + "_" + split, worker_factory)
    runner.start()
def main():
    """Print ranking errors of an evidence ranked list, one query at a time.

    Usage: <ranked_list_path>

    Query ids have the form ``<claim_id>_<perspective_id>``.  Queries are
    handled as follows:
      * no gold evidence: reported and skipped;
      * gold evidence never retrieved: reported and skipped;
      * otherwise precision at R is computed, where R is the number of
        retrieved gold entries.

    Precision > 0.99 prints "Good".  Otherwise the claim, the perspective,
    the retrieved gold entries, and the non-gold entries with rank < 3 are
    printed.
    """
    # Convert the int keys to str in one step rather than re-annotating the
    # same variable with a conflicting type.
    claim_text_d: Dict[str, str] = dict_key_map(str, get_all_claim_d())
    evi_dict: Dict[str, str] = dict_key_map(str, load_evidence_dict())
    evi_gold_dict: Dict[str, List[int]] = evidence_gold_dict_str_qid()
    print("V2")

    def print_entry(entry):
        evidence_text = evi_dict[entry.doc_id]
        print("[{}] {}: {}".format(entry.rank, entry.doc_id, evidence_text))

    ranked_list_dict = load_ranked_list_grouped(sys.argv[1])
    for query, ranked_list in ranked_list_dict.items():
        print()
        claim_id, perspective_id = query.split("_")
        gold_ids = set(lmap(str, evi_gold_dict[query]))
        if not gold_ids:
            print("query {} has no gold".format(query))
            continue
        claim_text = claim_text_d[claim_id]
        perspective_text = perspective_getter(int(perspective_id))

        pos_entries = [entry for entry in ranked_list if entry.doc_id in gold_ids]
        neg_entries = [entry for entry in ranked_list
                       if entry.doc_id not in gold_ids and entry.rank < 3]
        if not pos_entries:
            print("gold not in ranked list")
            continue

        num_rel = len(pos_entries)
        correctness = [int(entry.doc_id in gold_ids) for entry in ranked_list[:num_rel]]
        precision = average(correctness)
        if precision > 0.99:
            print("Good")
            continue

        print("precision at {}: {}".format(num_rel, precision))
        print("Claim: ", claim_text)
        print("perspective_text: ", perspective_text)
        print(" < GOLD >")
        foreach(print_entry, pos_entries)
        print(" < False Positive >")
        foreach(print_entry, neg_entries)
def main():
    """Merge two ranked lists, preferring the first one for every query.

    Usage: <first_list> <second_list> <save_path>

    All queries of the first list are kept as they are.  Queries found only
    in the second list are added after them, in the second list's order.
    The merged list is written in TREC format.
    """
    first_list_path = sys.argv[1]
    second_list_path = sys.argv[2]
    save_path = sys.argv[3]
    print("Use {} if available, if not use {}".format(first_list_path, second_list_path))

    primary: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(first_list_path)
    fallback: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(second_list_path)

    merged: Dict[str, List[TrecRankedListEntry]] = dict(primary)
    merged.update((qid, entries) for qid, entries in fallback.items() if qid not in primary)

    flat_entries: Iterable[TrecRankedListEntry] = flatten(merged.values())
    write_trec_ranked_list_entry(flat_entries, save_path)
def main():
    """Restrict a ranked list to the queries that also occur in a second list.

    Usage: <source_list> <reference_list> <save_path>

    Entries of the second list are not used; only its query ids matter.
    """
    source_path = sys.argv[1]
    reference_path = sys.argv[2]
    save_path = sys.argv[3]
    print("From {} select query that are in {}".format(source_path, reference_path))

    source: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(source_path)
    reference: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(reference_path)

    selected: Dict[str, List[TrecRankedListEntry]] = {
        qid: entries for qid, entries in source.items() if qid in reference
    }
    flat_entries: Iterable[TrecRankedListEntry] = flatten(selected.values())
    write_trec_ranked_list_entry(flat_entries, save_path)
def main():
    """Fetch URLs for all distinct documents of a ranked list and pickle them.

    Usage: <ranked_list_path>

    The result of ``get_urls`` over the distinct doc ids is saved with
    ``save_to_pickle`` under the name "urls_d".
    """
    ranked_list_path = sys.argv[1]
    grouped: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(ranked_list_path)
    all_entries: Iterable[TrecRankedListEntry] = flatten(grouped.values())
    distinct_doc_ids = list({entry.doc_id for entry in all_entries})
    save_to_pickle(get_urls(distinct_doc_ids), "urls_d")
def main():
    """Keep only the first k entries of every query of a ranked list.

    Usage: <ranked_list_path> <output_path> <k>
    """
    ranked_list_path = sys.argv[1]
    save_path = sys.argv[2]
    k = int(sys.argv[3])
    grouped: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(ranked_list_path)
    truncated = [entry for entries in grouped.values() for entry in entries[:k]]
    write_trec_ranked_list_entry(truncated, save_path)
def main2():
    """Write copies of two ranked lists with the BM25-retrieved documents removed.

    Usage: <proposed_tfidf_list> <proposed_bm25_list> <bert_tfidf_list> <qrels>

    For each query of the proposed tf-idf run, entries whose doc_id appears
    in the proposed BM25 run are dropped from both the proposed tf-idf run
    and the BERT tf-idf run.  The surviving entries are re-ranked from 0,
    keeping their scores and run names.  Results go to
    "bm25equi_proposed.txt" and "bm25equi_bert.txt".
    """
    rlg_proposed_tfidf = load_ranked_list_grouped(sys.argv[1])
    rlg_proposed_bm25 = load_ranked_list_grouped(sys.argv[2])
    rlg_bert_tfidf = load_ranked_list_grouped(sys.argv[3])
    qrel: QRelsDict = load_qrels_structured(sys.argv[4])

    def drop_and_rerank(entries, excluded_doc_ids):
        kept = [e for e in entries if e.doc_id not in excluded_doc_ids]
        return [
            TrecRankedListEntry(e.query_id, e.doc_id, new_rank, e.score, e.run_name)
            for new_rank, e in enumerate(kept)
        ]

    flat_etr1 = []
    flat_etr3 = []
    for q, entries1 in rlg_proposed_tfidf.items():
        entries2 = rlg_proposed_bm25[q]
        entries3 = rlg_bert_tfidf[q]
        bm25_doc_ids = set(map(TrecRankedListEntry.get_doc_id, entries2))
        # NOTE(review): the judgments are never used; this lookup only makes
        # the script fail with KeyError for a query missing from the qrels.
        _ = qrel[q]
        flat_etr1.extend(drop_and_rerank(entries1, bm25_doc_ids))
        flat_etr3.extend(drop_and_rerank(entries3, bm25_doc_ids))

    write_trec_ranked_list_entry(flat_etr1, "bm25equi_proposed.txt")
    write_trec_ranked_list_entry(flat_etr3, "bm25equi_bert.txt")
def main():
    """Write an HTML table that links every ranked-list entry to its document page.

    Usage: <ranked_list_path> <doc_html_dir> <save_path>

    One row per entry: query id, rank, score, and a link to
    ``./<doc_html_dir>/<doc_id>.html``.
    """
    first_list_path = sys.argv[1]
    dir_path = sys.argv[2]
    save_path = sys.argv[3]
    grouped: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(first_list_path)
    flat_entries: Iterable[TrecRankedListEntry] = flatten(grouped.values())

    html = HtmlVisualizer(save_path)
    rows = []
    for e in flat_entries:
        ahref = "<a href=\"./{}/{}.html\">{}</a>".format(dir_path, e.doc_id, e.doc_id)
        rows.append(lmap(Cell, [e.query_id, e.rank, e.score, ahref]))
    html.write_table(rows)
    # The visualizer was previously never closed, leaving the output file to
    # be finalized only at interpreter exit; close it explicitly.
    html.close()
def main():
    """Print the mean per-query binary NDCG@20 of a ranked list.

    Usage: <judgment_path> <ranked_list_path>

    Documents judged with a positive score count as relevant (gain 1).  A
    query with no relevant document scores 1.  A query missing from the
    qrels is reported and also scores 1.
    """
    judgment_path = sys.argv[1]
    ranked_list_path1 = sys.argv[2]
    qrels = load_qrels_flat(judgment_path)
    ranked_list: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(
        ranked_list_path1)
    # NOTE(review): stray debug marker; kept so the script's output is unchanged.
    print(37)
    k = 20

    per_score_list = []
    for query_id, q_ranked_list in ranked_list.items():
        # Guard only the qrels lookup so unrelated KeyErrors are not masked.
        try:
            gold_list = qrels[query_id]
        except KeyError:
            print("Query not found:", query_id)
            per_score_list.append(1)
            continue
        true_gold = {doc_id for doc_id, score in gold_list if score > 0}
        if not true_gold:
            per_score_list.append(1)
            continue
        label_list = [1 if e.doc_id in true_gold else 0 for e in q_ranked_list[:k]]
        # Ideal DCG: every relevant document (up to k) ranked first.
        idcg = dcg([1] * min(len(true_gold), k))
        per_score_list.append(dcg(label_list) / idcg)
    print(average(per_score_list))
def main():
    """Re-rank the MS MARCO document dev top-100 run with BM25.

    For each query of ``msmarco-doc<split>-top100`` (working dir), the
    documents returned by ``load_per_query_docs`` are scored with the BM25
    module on ``title + " " + body`` and re-ranked from 0 under the run name
    "BM25_df100".  Ranked-list doc ids missing from the loaded documents are
    counted and reported.  Processing stops once more than 100 * 100 entries
    have been produced.  Output: ``ranked_list/mmd_dev_BM25_df100.txt``
    under the output dir.
    """
    split = "dev"
    query_d = dict(load_queries(split))
    bm25_module = get_bm25_module()
    ranked_list_path = at_working_dir("msmarco-doc{}-top100".format(split))
    run_name = "BM25_df100"
    rlg = load_ranked_list_grouped(ranked_list_path)
    save_path = at_output_dir("ranked_list", "mmd_dev_{}.txt".format(run_name))
    te = TimeEstimator(100)
    max_entries = 100 * 100

    # Defined once and given the query explicitly; previously a nested
    # `score` function was shadowed by the loop variable `score` below.
    def bm25_score(query_text, doc: MSMarcoDoc):
        content = doc.title + " " + doc.body
        return bm25_module.score(query_text, content)

    out_entries = []
    for query_id, entries in rlg.items():
        docs = load_per_query_docs(query_id, None)
        found_doc_ids = {d.doc_id for d in docs}
        doc_id_len = sum(1 for e in entries if e.doc_id not in found_doc_ids)
        if doc_id_len:
            print("{} docs not found".format(doc_id_len))

        query_text = query_d[QueryID(query_id)]
        scored_docs = [(doc, bm25_score(query_text, doc)) for doc in docs]
        scored_docs.sort(key=get_second, reverse=True)
        out_entries.extend(
            TrecRankedListEntry(query_id, doc.doc_id, rank, doc_score, run_name)
            for rank, (doc, doc_score) in enumerate(scored_docs))

        te.tick()
        if len(out_entries) > max_entries:
            break
    write_trec_ranked_list_entry(out_entries, save_path)
def main():
    """Write a styled HTML table of ranked-list entries with titles and URLs.

    Usage: <ranked_list_path> <doc_html_dir> <save_path>

    Entries are enriched with ``enrich``, using the pickled doc_id -> URL
    map "urls_d".  Each row shows the query id, rank, score, a link to
    ``<doc_html_dir>/<doc_id>.html`` (opened in a new tab), the title, and
    the URL.
    """
    first_list_path = sys.argv[1]
    dir_path = sys.argv[2]
    save_path = sys.argv[3]
    grouped: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(first_list_path)

    def get_html_path_fn(doc_id):
        return os.path.join(dir_path, "{}.html".format(doc_id))

    doc_id_to_url = load_from_pickle("urls_d")
    flat_entries: Iterable[TrecRankedListEntry] = flatten(grouped.values())
    entries = [enrich(e, get_html_path_fn, doc_id_to_url) for e in flat_entries]

    html = HtmlVisualizer(save_path, additional_styles=[
        get_link_highlight_code(),
        get_bootstrap_include_source(),
    ])
    # NOTE(review): these header cells are built but never written;
    # write_table_with_class only receives `rows`.  Confirm whether they
    # were meant to be emitted as a header row.
    head = [
        get_table_head_cell("query"),
        get_table_head_cell("rank"),
        get_table_head_cell("score"),
        get_table_head_cell("doc_id"),
        get_table_head_cell("title", 300),
        get_table_head_cell("url"),
    ]
    rows = []
    for e in entries:
        # Reuse the same path helper given to `enrich` instead of duplicating it.
        ahref = "<a href=\"{}\" target=\"_blank\">{}</a>".format(
            get_html_path_fn(e.doc_id), e.doc_id)
        rows.append(lmap(Cell, [e.query_id, e.rank, e.score, ahref, e.title, e.url]))
    html.write_table_with_class(rows, "table")
    # The visualizer was previously never closed; close it explicitly.
    html.close()
def main():
    """Print macro-averaged precision, recall and F1 of score thresholding.

    Usage: <ranked_list_path> <qrels_path>

    For each threshold t in 0.0, 0.1, ..., 0.9, a document is predicted
    relevant when its score is strictly greater than t.  Per query:
      * precision is 1 when nothing is predicted;
      * recall is 1 when the query has no relevant document;
      * queries missing from the qrels have no relevant document.
    Precision and recall are averaged over queries, F1 is computed from
    the averages, and one ``t precision recall f1`` line is printed per
    threshold.
    """
    l1: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(sys.argv[1])
    qrel: Dict[str, Dict[str, int]] = load_qrels_structured(sys.argv[2])
    threshold_list = [i / 10 for i in range(10)]

    # The gold sets do not depend on the threshold, so build them once per
    # query instead of once per (threshold, query) pair.
    gold_per_query = {}
    for query in l1:
        gold_dict = qrel[query] if query in qrel else {}
        gold_per_query[query] = {doc_id for doc_id, label in gold_dict.items() if label}

    res = []
    for t in threshold_list:
        prec_list = []
        recall_list = []
        for query, ranked_list in l1.items():
            gold_docs = gold_per_query[query]
            pred_list = [e.doc_id for e in ranked_list if e.score > t]
            tp = len(gold_docs.intersection(pred_list))
            prec_list.append(tp / len(pred_list) if pred_list else 1)
            recall_list.append(tp / len(gold_docs) if gold_docs else 1)
        prec = average(prec_list)
        recall = average(recall_list)
        res.append((t, prec, recall, get_f1(prec, recall)))

    for t, prec, recall, f1 in res:
        print(t, prec, recall, f1)
def main(input_path):
    """Print diagnostics for claims whose perspective ranking has low precision.

    For each claim query of the ranked list at ``input_path``, precision is
    computed at min(number of gold perspectives, 5).  A ranked perspective
    is correct if it belongs to any gold cluster of the claim.  When that
    precision is below 0.3, the following are printed:
      * the gold count, the precision, and the claim text;
      * every gold cluster, with each perspective's rank (or "X" if not
        ranked);
      * the top 50 ranked perspectives, marked Y/N.

    Claims without any gold perspective are skipped; previously they raised
    ZeroDivisionError.
    """
    claims = get_all_claims()
    claim_d = claims_to_dict(claims)
    gold: Dict[int, List[List[int]]] = get_claim_perspective_id_dict()
    grouped_ranked_list = load_ranked_list_grouped(input_path)

    def is_correct(qid: str, doc_id: str):
        return any(int(doc_id) in cluster for cluster in gold[int(qid)])

    top_k = 5
    for qid, entries in grouped_ranked_list.items():
        clusters = gold[int(qid)]
        n_gold = sum(map(len, clusters))
        cut_n = min(n_gold, top_k)
        if cut_n == 0:
            # No gold perspective: precision is undefined, nothing to diagnose.
            continue
        num_correct = sum(1 for e in entries[:cut_n] if is_correct(qid, e.doc_id))
        p_at_k = num_correct / cut_n
        if p_at_k >= 0.3:
            continue

        pid_to_rank: Dict[str, int] = {e.doc_id: e.rank for e in entries}

        def get_rank(pid: int):
            return pid_to_rank.get(str(pid), "X")

        print(n_gold)
        print(p_at_k)
        print("Claim {} {}".format(qid, claim_d[int(qid)]))
        for cluster in clusters:
            print("-")
            for pid in cluster:
                print("[{}]".format(get_rank(pid)), perspective_getter(int(pid)))
        for e in entries[:50]:
            correct_str = "Y" if is_correct(qid, e.doc_id) else "N"
            print("{} {} {}".format(correct_str, e.score, perspective_getter(int(e.doc_id))))
def load_clueweb09_ranked_list() -> Dict[str, List[TrecRankedListEntry]]:
    """Return the ClueWeb09 ranked list, grouped by query.

    The file read is ``clueweb/clue09_ranked_list.txt`` under the output
    directory.
    """
    return load_ranked_list_grouped(at_output_dir("clueweb", "clue09_ranked_list.txt"))
def main():
    """Compare ranking documents by their first passage vs. by their best passage.

    Usage: <passage_ranked_list> <qrels>

    Input doc ids have the form ``<doc_id>_<passage_idx>``, one entry per
    passage.  Per query, passages are grouped into documents, and each
    document is ranked twice:
      * "head": by the score of its lowest-index passage;
      * "max": by its highest passage score.

    When a relevant document ranks better by head than by max, its passage
    scores are printed, followed by the first non-relevant multi-passage
    document ranked above it by max.

    With several relevant documents, the one ranked last by head is examined
    and compared against its own rank by max.  Previously the max rank could
    belong to a different relevant document, and other relevant documents
    could be reported as false positives.
    """
    non_combine_ranked_list: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(
        sys.argv[1])
    judgement: QRelsDict = load_qrels_structured(sys.argv[2])
    # TODO : Find the case where scoring first segment would get better score
    # TODO : Check if the FP, TP document has score in low range or high range
    # TODO : Check if FP

    def parse_doc_name(doc_name):
        tokens = doc_name.split("_")
        doc_id = "_".join(tokens[:-1])
        passage_idx = int(tokens[-1])
        return doc_id, passage_idx

    class Passage(NamedTuple):
        doc_id: str
        score: float
        passage_idx: int

        def get_doc_id(self):
            return self.doc_id

        def get_score(self):
            return self.score

    class Entry(NamedTuple):
        doc_id: str
        max_score: float
        max_passage_idx: int
        num_passage: int
        passages: List[Passage]  # ordered by passage_idx

    def get_passage_score_str(passages: List[Passage]):
        passage_scores: List[float] = lmap(Passage.get_score, passages)
        return " ".join(map(two_digit_float, passage_scores))

    for query_id, entries in non_combine_ranked_list.items():
        passages = []
        for e in entries:
            doc_id, passage_idx = parse_doc_name(e.doc_id)
            passages.append(Passage(doc_id, e.score, passage_idx))
        grouped: Dict[str, List[Passage]] = group_by(passages, get_first)

        grouped_entries: List[Entry] = []
        doc_id_to_entry: Dict[str, Entry] = {}
        score_by_head: List[Tuple[str, float]] = []
        for doc_id, scored_passages in grouped.items():
            scored_passages.sort(key=lambda p: p.passage_idx)
            score_by_head.append((doc_id, scored_passages[0].score))
            # max() returns the first maximum, i.e. the lowest passage index
            # among ties -- the same pick as a stable descending sort.
            best = max(scored_passages, key=Passage.get_score)
            doc_entry = Entry(doc_id, best.score, best.passage_idx,
                              len(scored_passages), scored_passages)
            grouped_entries.append(doc_entry)
            doc_id_to_entry[doc_id] = doc_entry

        rel_d = judgement[query_id]

        def is_relevant(doc_id):
            return doc_id in rel_d and rel_d[doc_id]

        score_by_head.sort(key=get_second, reverse=True)
        rel_rank_by_head = -1
        rel_doc_id = None
        for rank, (doc_id, _head_score) in enumerate(score_by_head):
            if is_relevant(doc_id):
                rel_rank_by_head = rank
                rel_doc_id = doc_id
        if rel_doc_id is None:
            # No relevant document retrieved for this query: nothing to compare.
            continue

        grouped_entries.sort(key=lambda x: x.max_score, reverse=True)
        # Rank by max of the *same* document whose head rank was recorded.
        rel_rank_by_max = next(rank for rank, entry in enumerate(grouped_entries)
                               if entry.doc_id == rel_doc_id)

        if rel_rank_by_head < rel_rank_by_max:
            print()
            print("< Relevant document >")
            print("Rank by head", rel_rank_by_head)
            print("Rank by max", rel_rank_by_max)
            rel_entry = doc_id_to_entry[rel_doc_id]
            print(get_passage_score_str(rel_entry.passages))
            print("Num passages", rel_entry.num_passage)
            for rank, entry in enumerate(grouped_entries[:rel_rank_by_max]):
                if len(entry.passages) > 1 and not is_relevant(entry.doc_id):
                    print("< False positive document >")
                    print("Rank by max:", rank)
                    print("doc_id", entry.doc_id)
                    print("Num passages", entry.num_passage)
                    print("max_score", entry.max_score)
                    print("Passages scores: ", get_passage_score_str(entry.passages))
                    break
def get_candidate_ids_from_ranked_list_path(
        ranked_list_path) -> Dict[str, List[str]]:
    """Return the candidate ids per query of the ranked list file at ``ranked_list_path``.

    The file is loaded grouped by query, and the result of
    ``get_candidate_ids_from_ranked_list`` on it is returned.
    """
    return get_candidate_ids_from_ranked_list(load_ranked_list_grouped(ranked_list_path))
def main():
    """Compare two ranked lists whose scores are relevance probabilities.

    Usage: <first_list> <second_list> <judgment_path>

    A document counts as a gold positive when it is judged with a score > 0.
    For each list, prints:
      * the mean binary cross-entropy of the scores;
      * the accuracy of the prediction ``score > 0.5``;
      * the confusion counts (tp, fp, tn, fn).
    Queries missing from the qrels raise KeyError.
    """
    first_list_path = sys.argv[1]
    second_list_path = sys.argv[2]
    # The banner previously said "Use ... if available, if not use ...",
    # copy-pasted from the list-merging script.
    print("Compare {} and {}".format(first_list_path, second_list_path))
    l1: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(first_list_path)
    l2: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(second_list_path)
    judgment_path = sys.argv[3]
    qrels: Dict[QueryID, List[Tuple[DocID, int]]] = load_qrels_flat(judgment_path)

    def eval_loss(prob, label):
        # NOTE(review): math.log raises ValueError for prob == 0 (label True)
        # or prob == 1 (label False) and for scores outside [0, 1]; scores
        # are assumed to be probabilities.
        if label:
            return -math.log(prob)
        return -math.log(1 - prob)

    def iter_labeled(l):
        """Yield (entry, is_gold_positive) for every entry of every query in ``l``."""
        for query_id, ranked_list in l.items():
            true_gold = {doc_id for doc_id, score in qrels[query_id] if score > 0}
            for e in ranked_list:
                yield e, e.doc_id in true_gold

    def get_loss(l):
        return average([eval_loss(e.score, label) for e, label in iter_labeled(l)])

    def get_acc(l):
        return average([int((e.score > 0.5) == label) for e, label in iter_labeled(l)])

    def get_tf(l):
        tp = fp = tn = fn = 0
        for e, label in iter_labeled(l):
            pred_true = e.score > 0.5
            if pred_true and label:
                tp += 1
            elif pred_true:
                fp += 1
            elif label:
                # Predicted negative but relevant: a false negative.  The
                # previous version swapped tn and fn in this branch.
                fn += 1
            else:
                tn += 1
        return tp, fp, tn, fn

    print("loss1:", get_loss(l1))
    print("loss2:", get_loss(l2))
    print("acc1:", get_acc(l1))
    print("acc2:", get_acc(l2))
    print("1: tp fp tn fn ", get_tf(l1))
    print("2: tp fp tn fn :", get_tf(l2))
def main():
    """Write an HTML viewer of imperfect Robust04 rankings from a BERT run.

    Reads ``<output_path>/ranked_list/robust_V_10K_10000.txt``.  A query is
    included when it has candidate documents and judgments, and its ranking
    is not perfect (some relevant document is ranked after a non-relevant
    one).  For each included query, "robust_V_predictions.html" gets a
    collapsible section:
      * a button showing the query text and P@10;
      * one line per ranked document, green if judged relevant and red
        otherwise, which expands to the document text.
    """
    ranked_list_path = os.path.join(output_path, "ranked_list", "robust_V_10K_10000.txt")
    bert_ranked_list = load_ranked_list_grouped(ranked_list_path)
    queries: Dict[str, str] = load_robust04_desc2()
    qrels = load_robust04_qrels()
    candidates_d = load_candidate_d()
    # A dedicated name, so it is not confused with the per-row inline style below.
    css_styles = [get_collapsible_css(), get_scroll_css()]
    html = HtmlVisualizer("robust_V_predictions.html", additional_styles=css_styles)

    def get_label(judgement, doc_id):
        return judgement[doc_id] if doc_id in judgement else 0

    def get_labels(judgement, ranked_list):
        return [get_label(judgement, e.doc_id) for e in ranked_list]

    def is_perfect(judgement, ranked_list):
        """True when no relevant document is ranked after a non-relevant one."""
        seen_non_relevant = False
        for label in get_labels(judgement, ranked_list):
            if not label:
                seen_non_relevant = True
            elif seen_non_relevant:
                return False
        return True

    def p_at_k(judgement, ranked_list, k=10):
        label_list = get_labels(judgement, ranked_list)
        num_correct = sum(1 for label in label_list[:k] if label)
        return num_correct / k

    for qid, ranked_list in bert_ranked_list.items():
        if qid not in candidates_d or qid not in qrels:
            continue
        judgement = qrels[qid]
        q_text = queries[qid]
        if is_perfect(judgement, ranked_list):
            continue

        html.write_div_open()
        text = "{0}: {1} ({2:.2f})".format(qid, q_text, p_at_k(judgement, ranked_list))
        html.write_elem("button", text, "collapsible")
        html.write_div_open("content")
        doc_text_d = dict(candidates_d[qid])
        for e in ranked_list:
            doc_id = e.doc_id
            row_style = "font-size: 13px; padding: 8px;"
            if get_label(judgement, doc_id):
                row_style += " background-color: DarkGreen"
            else:
                row_style += " background-color: DarkRed"
            text = "{0}] {1} ({2:.2f})".format(e.rank, doc_id, e.score)
            html.write_elem("p", text, "collapsible", row_style)
            html.write_div(doc_text_d[doc_id], "c_content")
        html.write_div_close()  # closes the "content" div
        html.write_div_close()  # closes the per-query div
    html.write_script(get_collapsible_script())
    html.close()