Ejemplo n.º 1
0
def load_clef_qrels() -> Dict[str, List[str]]:
    """Return the relevant doc ids per query, merged from the CLEF eHealth 2017 and 2016 qrels.

    A document is considered relevant when its judged score is > 0. For each
    query, the 2017 doc ids come first, followed by the 2016 doc ids (no
    de-duplication is performed). Queries that appear only in the 2016 qrels
    are included too.
    """
    assessments_dir = os.path.join(data_path, "CLEFeHealth2017IRtask", "assessments")
    q_rel_d1 = load_qrels_flat(os.path.join(assessments_dir, "2017", "clef2017_qrels.txt"))
    q_rel_d2 = load_qrels_flat(os.path.join(assessments_dir, "2016", "task1.qrels"))

    def fn(pair_list):
        # Keep only doc ids with a positive relevance score.
        return [doc_id for doc_id, score in pair_list if score > 0]

    q_rel_1 = dict_value_map(fn, q_rel_d1)
    q_rel_2 = dict_value_map(fn, q_rel_d2)

    for key, doc_ids in q_rel_2.items():
        # setdefault: a query judged only in 2016 previously raised KeyError here.
        q_rel_1.setdefault(key, []).extend(doc_ids)

    return q_rel_1
Ejemplo n.º 2
0
def do_for_robust():
    """Run ``analyze_overlap`` on the Robust04 galago ranked list, train vs dev.

    Query ids <= 650 form the "train" split and ids >= 651 the "dev" split.
    Each query's ranked list is truncated to its top 100 entries.
    """
    # Count candidates that appear as positive in training split but negative in dev/test
    judgment_path = os.path.join(data_path, "robust", "qrels.rob04.txt")
    qrels = load_qrels_flat(judgment_path)
    # NOTE(review): query 672 is forced to an empty judgment list — presumably
    # because the qrels file has no positive judgments for it; confirm.
    qrels['672'] = []

    ranked_list_path = os.path.join(data_path, "robust",
                                    "rob04.desc.galago.2k.out")
    rlg_all: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(
        ranked_list_path)

    def is_in_split(split, qid):
        if split == "train":
            return int(qid) <= 650
        if split == "dev":
            return int(qid) >= 651
        # An explicit raise, unlike `assert False`, survives `python -O`.
        raise ValueError("Unknown split: {}".format(split))

    def get_ranked_list(split):
        # Top-100 entries of every query that belongs to ``split``.
        return {qid: rl[:100] for qid, rl in rlg_all.items()
                if is_in_split(split, qid)}

    train_splits = ["train"]
    eval_splits = ["dev"]
    analyze_overlap(get_ranked_list, qrels, train_splits, eval_splits)
Ejemplo n.º 3
0
def do_for_new_perspectrum():
    """Run ``analyze_overlap`` on perspective candidates re-grouped by new_split.json.

    Ranked lists from every split in the module-level ``splits`` are pooled.
    They are then re-partitioned using the query ids that
    ``load_qid_for_split`` returns for the split assignment read from
    ``perspective/new_split.json``. The analysis compares "train" with "dev".
    """
    judgment_path = os.path.join(data_path, "perspective", "qrel.txt")
    qrels = load_qrels_flat(judgment_path)

    candidate_set_name = "pc_qres"

    # candidate_set_name = "default_qres"

    def get_old_ranked_list(split):
        ranked_list_path = os.path.join(output_path, "perspective_experiments",
                                        candidate_set_name,
                                        "{}.txt".format(split))
        rlg: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(
            ranked_list_path)
        return rlg

    rlg_all: Dict[str, List[TrecRankedListEntry]] = dict()
    for split in splits:
        rlg_all.update(get_old_ranked_list(split))

    split_info_path = os.path.join(data_path, "perspective", "new_split.json")
    # Context manager so the file handle is closed deterministically.
    with open(split_info_path, "r") as f:
        new_splits: Dict[str, str] = json.load(f)

    def get_new_ranked_list(split):
        qids: Iterable[str] = load_qid_for_split(new_splits, split)
        return {qid: rlg_all[qid] for qid in qids}

    train_splits = ["train"]
    eval_splits = ["dev"]
    analyze_overlap(get_new_ranked_list, qrels, train_splits, eval_splits)
Ejemplo n.º 4
0
def main():
    """Print per-query score differences between two runs, largest gain first.

    Usage: <judgment_path> <metric> <ranked_list_1> <ranked_list_2>

    Each stdout line is "query_id diff score1 score2", where diff = score2 - score1.
    Only queries scored in both runs are listed. If any query is unmatched, a
    note with the matched count goes to stderr so the stdout data stays clean.
    """
    judgment_path = sys.argv[1]
    metric = sys.argv[2]
    ranked_list_path1 = sys.argv[3]
    ranked_list_path2 = sys.argv[4]
    qrels = load_qrels_flat(judgment_path)

    ranked_list_1: Dict[str,
                        List[TrecRankedListEntry]] = load_ranked_list_grouped(
                            ranked_list_path1)
    ranked_list_2: Dict[str,
                        List[TrecRankedListEntry]] = load_ranked_list_grouped(
                            ranked_list_path2)

    metric_fn = get_metric_fn(metric)

    score_d1 = get_score_per_query(qrels, metric_fn, ranked_list_1)
    score_d2 = get_score_per_query(qrels, metric_fn, ranked_list_2)

    pairs = []
    for key, score1 in score_d1.items():
        try:
            score2 = score_d2[key]
        except KeyError:
            continue  # query not scored in the second run
        pairs.append((key, score1, score2))

    if len(pairs) < len(score_d1) or len(pairs) < len(score_d2):
        print("{} matched from {} and {} scores".format(len(pairs), len(score_d1), len(score_d2)),
              file=sys.stderr)

    # Largest improvement of run 2 over run 1 first.
    pairs.sort(key=lambda t: t[2] - t[1], reverse=True)

    for query_id, score1, score2 in pairs:
        print("{0} {1:.2f} {2:.2f} {3:.2f}".format(query_id, score2 - score1,
                                                   score1, score2))
Ejemplo n.º 5
0
def main():
    """Run a paired t-test between two runs' per-query scores.

    Usage: <judgment_path> <metric> <baseline_ranked_list> <treatment_ranked_list>

    Only queries scored in both runs are compared. A note is printed when some
    queries are unmatched, and the test is skipped when no query overlaps.
    """
    judgment_path = sys.argv[1]
    metric = sys.argv[2]
    ranked_list_path1 = sys.argv[3]
    ranked_list_path2 = sys.argv[4]
    qrels = load_qrels_flat(judgment_path)

    ranked_list_1: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(ranked_list_path1)
    ranked_list_2: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(ranked_list_path2)

    metric_fn = get_metric_fn(metric)

    score_d1 = get_score_per_query(qrels, metric_fn, ranked_list_1)
    score_d2 = get_score_per_query(qrels, metric_fn, ranked_list_2)

    pairs = []
    for key, score1 in score_d1.items():
        try:
            score2 = score_d2[key]
        except KeyError:
            continue  # query not scored in the treatment run
        pairs.append((score1, score2))

    if len(pairs) < len(score_d1) or len(pairs) < len(score_d2):
        print("{} matched from {} and {} scores".format(len(pairs), len(score_d1), len(score_d2)))

    if not pairs:
        # zip(*[]) yields nothing, so unpacking into l1, l2 would raise ValueError.
        print("No common queries; skipping t-test")
        return

    l1, l2 = zip(*pairs)
    t_stat, p_value = stats.ttest_rel(l1, l2)
    print("baseline:", average(l1))
    print("treatment:", average(l2))
    print(t_stat, p_value)
Ejemplo n.º 6
0
def split_qids() -> Tuple[List[str], List[str], List[str]]:
    """Randomly split perspective query ids into (train, dev, test), keeping similar queries together.

    Query pairs from ``get_similar_pairs`` are grouped into clusters by
    ``make_clusters``. The clusters are shuffled and handed out whole: first to
    train until it holds at least 541 qids, then to test until it holds at
    least 227 qids. Every remaining cluster goes to dev. The asserts check
    that the three splits are disjoint, duplicate-free and cover every qid.
    """
    judgment_path = os.path.join(data_path, "perspective", "qrel.txt")
    qrels = load_qrels_flat(judgment_path)

    all_qids = list(qrels.keys())
    n_qids = len(all_qids)

    paired_qids = get_similar_pairs(all_qids, qrels)

    # Symmetric adjacency list of "similar" queries.
    rel_qid_d = defaultdict(list)
    for qid1, qid2 in paired_qids:
        rel_qid_d[qid1].append(qid2)
        rel_qid_d[qid2].append(qid1)

    random.shuffle(all_qids)
    # Previous split size
    # 'train': 541,
    # 'dev': 139,
    # 'test': 227,
    min_train_claim = 541
    min_test_claim = 227

    clusters: List[List[str]] = make_clusters(all_qids, rel_qid_d)
    random.shuffle(clusters)

    def pool_cluster(remaining_clusters: List[List[str]], start_idx,
                     n_minimum):
        """Take whole clusters from ``start_idx`` until at least ``n_minimum`` qids are pooled.

        Returns the pooled qids and the index of the first unused cluster.
        The size of each taken cluster is printed on one line.
        """
        idx = start_idx
        selected_qids = []
        while idx < len(remaining_clusters) and len(selected_qids) < n_minimum:
            cur_cluster = remaining_clusters[idx]
            print(len(cur_cluster), end=" ")
            selected_qids.extend(cur_cluster)
            idx += 1
        print()
        return selected_qids, idx

    print("train")
    train_qids, last_idx = pool_cluster(clusters, 0, min_train_claim)
    print("test")
    test_qids, last_idx = pool_cluster(clusters, last_idx, min_test_claim)
    dev_qids = lflatten(clusters[last_idx:])
    qid_splits = [train_qids, test_qids, dev_qids]

    assert n_qids == sum(map(len, qid_splits))

    def overlap(l1: List, l2: List) -> Set:
        return set(l1).intersection(l2)

    assert not overlap(train_qids, test_qids)
    assert not overlap(train_qids, dev_qids)
    assert not overlap(dev_qids, test_qids)
    for split_qid in qid_splits:
        assert len(set(split_qid)) == len(split_qid)

    return train_qids, dev_qids, test_qids
Ejemplo n.º 7
0
def _merge_qrels(q_rel_d1, q_rel_d2):
    """Return the union of two flat qrels mappings (query id -> [(doc_id, score), ...]).

    Every query in either mapping is kept: keys of ``q_rel_d2`` first, then
    keys found only in ``q_rel_d1``. Within a query, each doc id appears once,
    in order of first appearance in ``q_rel_d2`` followed by ``q_rel_d1``.
    When both mappings judge the same doc, the ``q_rel_d1`` score wins.
    Scores are cast to ``int``.
    """
    from itertools import chain

    combined: Dict[str, List[Tuple[str, int]]] = {}
    extra_keys = [key for key in q_rel_d1 if key not in q_rel_d2]
    for key in chain(q_rel_d2, extra_keys):
        judgments = {}
        # Later assignments overwrite earlier ones, so q_rel_d1 takes precedence.
        for doc_id, score in chain(q_rel_d2.get(key, ()), q_rel_d1.get(key, ())):
            judgments[doc_id] = int(score)
        combined[key] = list(judgments.items())
    return combined


def combine_qrels():
    """Write the union of the CLEF eHealth 2017 and 2016 qrels to ``combined.qrels``.

    Queries present in either file are kept. For a document judged in both
    years, the 2017 judgment is used.
    """
    base_dir = os.path.join(data_path, "CLEFeHealth2017IRtask")
    q_rel_d1 = load_qrels_flat(os.path.join(base_dir, "assessments", "2017", "clef2017_qrels.txt"))
    q_rel_d2 = load_qrels_flat(os.path.join(base_dir, "assessments", "2016", "task1.qrels"))

    combined = _merge_qrels(q_rel_d1, q_rel_d2)

    save_path = os.path.join(base_dir, "combined.qrels")
    write_qrels(combined, save_path)
Ejemplo n.º 8
0
def main():
    """Rank train-split candidates with the query language models and print MAP.

    MAP is computed by ``get_map`` against ``perspective/evidence_qrel.txt``.
    """
    split = "train"
    print("get query lms")
    query_lms: Dict[str, Counter] = get_query_lms(split)
    candidate_dict: Dict[str, List[QCKCandidate]] = get_candidate_full_text(split)
    ranked = rank_with_query_lm(query_lms, candidate_dict)
    qrels: QRelsFlat = load_qrels_flat(
        os.path.join(data_path, "perspective", "evidence_qrel.txt"))
    print(get_map(ranked, qrels))
Ejemplo n.º 9
0
def main():
    """Rewrite a qrels file with every non-positive relevance score replaced by 0.

    Usage: <judgment_path> <save_path>
    """
    judgment_path = sys.argv[1]
    save_path = sys.argv[2]
    qrels = load_qrels_flat(judgment_path)

    def iter_judgements():
        # Named to avoid shadowing the builtin ``iter``.
        for query_id, docs in qrels.items():
            for doc_id, raw_score in docs:
                score = raw_score if raw_score > 0 else 0
                yield TrecRelevanceJudgementEntry(query_id, doc_id, score)

    write_trec_relevance_judgement(iter_judgements(), save_path)
Ejemplo n.º 10
0
def main():
    """Rank candidates with query LMs built from evidence QK units and print MAP.

    Only queries that have a query LM are ranked. Ranking uses a depth of 100
    and alpha = 0.5. MAP is computed against ``perspective/evidence_qrel.txt``.
    """
    print("get query lms")
    split = "train"
    # The filtered set "pc_evi_filtered_qk_<split>" used to be loaded here too,
    # but the next assignment always overwrote it; only "pc_evidence_qk" was
    # ever used, so the dead load has been dropped.
    qk_candidate: List[QKUnit] = load_from_pickle("pc_evidence_qk")
    candidate_dict: Dict[str, List[QCKCandidate]] = get_candidate_full_text(split)
    query_lms: Dict[str, Counter] = kdp_to_lm(qk_candidate)
    # Restrict ranking to queries for which an LM exists (dict lookup, not a list scan).
    target_candidate_dict = {k: c for k, c in candidate_dict.items() if k in query_lms}
    alpha = 0.5
    print("alpha", alpha)
    q_ranked_list = rank_with_query_lm(query_lms, target_candidate_dict, 100, alpha)
    qrel_path = os.path.join(data_path, "perspective", "evidence_qrel.txt")
    qrels: QRelsFlat = load_qrels_flat(qrel_path)
    score = get_map(q_ranked_list, qrels)
    print(score)
Ejemplo n.º 11
0
def main():
    """Print the mean nDCG@k of a ranked list under binary relevance.

    Usage: <judgment_path> <ranked_list_path> [k]   (k defaults to 20)

    A document counts as relevant when its qrels score is > 0. Queries with no
    qrels entry, or with no relevant documents, score 1.
    """
    judgment_path = sys.argv[1]
    ranked_list_path1 = sys.argv[2]
    # Optional rank cutoff; the default matches the previously hard-coded value.
    k = int(sys.argv[3]) if len(sys.argv) > 3 else 20
    qrels = load_qrels_flat(judgment_path)

    ranked_list: Dict[str,
                      List[TrecRankedListEntry]] = load_ranked_list_grouped(
                          ranked_list_path1)

    per_score_list = []
    for query_id, q_ranked_list in ranked_list.items():
        try:
            gold_list = qrels[query_id]
        except KeyError:
            print("Query not found:", query_id)
            per_score_list.append(1)
            continue

        # Set gives O(1) membership tests instead of a list scan per document.
        true_gold = {doc_id for doc_id, score in gold_list if score > 0}
        if not true_gold:
            per_score_list.append(1)
            continue

        # Only the top-k labels contribute to DCG@k.
        label_list = [1 if e.doc_id in true_gold else 0 for e in q_ranked_list[:k]]
        idcg = dcg([1] * min(len(true_gold), k))
        per_score_list.append(dcg(label_list) / idcg)

    print(average(per_score_list))
Ejemplo n.º 12
0
def do_for_perspectrum():
    """Run ``analyze_overlap`` on the perspective "pc_qres" candidate lists, train vs dev.

    Per the original note: counts candidates that appear as positive in the
    training split but negative in dev/test.
    """
    qrels = load_qrels_flat(os.path.join(data_path, "perspective", "qrel.txt"))

    candidate_set_name = "pc_qres"  # alternative: "default_qres"

    def get_ranked_list(split):
        """Load the grouped ranked list stored as ``<split>.txt`` for the candidate set."""
        file_name = "{}.txt".format(split)
        path = os.path.join(output_path, "perspective_experiments",
                            candidate_set_name, file_name)
        return load_ranked_list_grouped(path)

    analyze_overlap(get_ranked_list, qrels, ["train"], ["dev"])
Ejemplo n.º 13
0
def split_qids():
    """Randomly split perspective query ids into (train, dev, test), keeping similar queries together.

    After shuffling, queries are taken in order. Each unseen query pulls in
    its whole connected component in the similarity graph (breadth-first).
    Train is filled until it has at least 541 qids, then test until it has
    at least 227. Every qid not yet assigned goes to dev. The asserts check
    that the splits are disjoint, duplicate-free and cover every qid.
    """
    from collections import deque

    judgment_path = os.path.join(data_path, "perspective", "qrel.txt")
    qrels = load_qrels_flat(judgment_path)

    all_qids = list(qrels.keys())
    n_qids = len(all_qids)

    paired_qids = get_similar_pairs(all_qids, qrels)

    # Symmetric adjacency list of "similar" queries.
    rel_qid_d = defaultdict(list)
    for qid1, qid2 in paired_qids:
        rel_qid_d[qid1].append(qid2)
        rel_qid_d[qid2].append(qid1)

    random.shuffle(all_qids)
    # Previous split size
    # 'train': 541,
    # 'dev': 139,
    # 'test': 227,
    min_train_claim = 541
    min_test_claim = 227

    def pool_qids(remaining_qids: List[str], seen_qids: Set[str],
                  n_minimum: float):
        """Pool whole components until at least ``n_minimum`` qids are selected.

        Mutates ``seen_qids``. Returns the selected qids and the unvisited
        tail of ``remaining_qids``; that tail may still contain qids already
        seen via the graph.
        """
        selected_qids = []
        idx = 0
        while len(selected_qids) < n_minimum and idx < len(remaining_qids):
            cur_qid = remaining_qids[idx]
            if cur_qid not in seen_qids:
                # deque gives O(1) pops; slicing the list made the BFS quadratic.
                queue = deque([cur_qid])
                seen_qids.add(cur_qid)
                while queue:
                    qid_being_added = queue.popleft()
                    selected_qids.append(qid_being_added)
                    for qid in rel_qid_d[qid_being_added]:
                        if qid not in seen_qids:
                            queue.append(qid)
                            seen_qids.add(qid)
            idx += 1
        return selected_qids, remaining_qids[idx:]

    seen_qids = set()
    train_qids, remaining_qids = pool_qids(all_qids, seen_qids,
                                           min_train_claim)
    test_qids, remaining_qids = pool_qids(remaining_qids, seen_qids,
                                          min_test_claim)
    dev_qids = [qid for qid in remaining_qids if qid not in seen_qids]

    qid_splits = [train_qids, test_qids, dev_qids]

    assert n_qids == sum(map(len, qid_splits))

    def overlap(l1: List, l2: List) -> Set:
        return set(l1).intersection(l2)

    assert not overlap(train_qids, test_qids)
    assert not overlap(train_qids, dev_qids)
    assert not overlap(dev_qids, test_qids)
    for split_qid in qid_splits:
        assert len(set(split_qid)) == len(split_qid)

    return train_qids, dev_qids, test_qids
Ejemplo n.º 14
0
def _relevant_doc_ids(qrels, query_id):
    """Return the set of doc ids judged relevant (score > 0) for ``query_id``.

    Raises KeyError if ``query_id`` has no qrels entry.
    """
    return {doc_id for doc_id, score in qrels[query_id] if score > 0}


def _log_loss(prob, label, eps=1e-15) -> float:
    """Binary cross-entropy of predicted probability ``prob`` against boolean ``label``.

    The log argument is floored at ``eps`` so that probabilities of exactly
    0 or 1 do not raise ``ValueError`` from ``math.log(0)``.
    """
    p = prob if label else 1 - prob
    return -math.log(max(p, eps))


def _confusion_counts(ranked_lists, qrels) -> tuple:
    """Return (tp, fp, tn, fn) for thresholding each entry's ``score`` at 0.5.

    ``ranked_lists`` maps query id -> entries that have ``doc_id`` and
    ``score`` attributes. A doc is positive when its qrels score is > 0.
    """
    tp = fp = tn = fn = 0
    for query_id, ranked_list in ranked_lists.items():
        relevant = _relevant_doc_ids(qrels, query_id)
        for e in ranked_list:
            label = e.doc_id in relevant
            pred_true = e.score > 0.5
            if pred_true:
                if label:
                    tp += 1
                else:
                    fp += 1
            elif label:
                fn += 1  # relevant but predicted negative (was miscounted as tn)
            else:
                tn += 1  # non-relevant and predicted negative (was miscounted as fn)
    return tp, fp, tn, fn


def main():
    """Compare two scored lists by log loss, accuracy and confusion counts.

    Usage: <first_list> <second_list> <judgment_path>

    Each entry's ``score`` is treated as a probability of relevance, with a
    prediction threshold of 0.5. A doc is labelled relevant when its qrels
    score is > 0.
    """
    first_list_path = sys.argv[1]
    second_list_path = sys.argv[2]
    print("Use {} if available, if not use {}".format(first_list_path,
                                                      second_list_path))
    l1: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(
        first_list_path)
    l2: Dict[str, List[TrecRankedListEntry]] = load_ranked_list_grouped(
        second_list_path)

    judgment_path = sys.argv[3]
    qrels: Dict[QueryID, List[Tuple[DocID,
                                    int]]] = load_qrels_flat(judgment_path)

    def get_loss(l):
        loss_all = []
        for query_id, ranked_list in l.items():
            relevant = _relevant_doc_ids(qrels, query_id)
            loss_all.extend(_log_loss(e.score, e.doc_id in relevant)
                            for e in ranked_list)
        return average(loss_all)

    def get_acc(l):
        correctness = []
        for query_id, ranked_list in l.items():
            relevant = _relevant_doc_ids(qrels, query_id)
            correctness.extend(int((e.score > 0.5) == (e.doc_id in relevant))
                               for e in ranked_list)
        return average(correctness)

    print("loss1:", get_loss(l1))
    print("loss2:", get_loss(l2))

    print("acc1:", get_acc(l1))
    print("acc2:", get_acc(l2))

    print("1: tp fp tn fn ", _confusion_counts(l1, qrels))
    print("2: tp fp tn fn :", _confusion_counts(l2, qrels))