示例#1
0
def main(config):
    """Print one score table per query from QCK predictions.

    Reads three entries from ``config``: ``'info_path'``, ``'pred_path'`` and
    ``'qrel_path'``.  For every query the printed table has one column per
    candidate id and one row per (doc_id, passage_idx) pair; each cell is the
    softmax score (index 1) of the logits for that combination, or ``"-"``
    when no prediction exists for it.  The first two rows are the qrel labels
    and the candidate ids.

    :param config: mapping providing the three paths listed above.
    """
    info_dir = config['info_path']
    prediction_file = config['pred_path']

    f_handler = get_format_handler("qck")
    info = load_combine_info_jsons(info_dir, f_handler.get_mapping(),
                                   f_handler.drop_kdp())
    data: List[Dict] = join_prediction_with_info(prediction_file, info,
                                                 ["data_id", "logits"])
    out_entries: List[QCKOutEntry] = lmap(QCKOutEntry.from_dict, data)
    qrel: Dict[str, Dict[str,
                         int]] = load_qrels_structured(config['qrel_path'])

    def get_label(query_id, candi_id):
        # Anything without a judgement counts as 0.  This covers both an
        # unknown candidate and a query that has no judgements at all, so a
        # single unjudged query cannot abort the whole dump.
        try:
            return qrel[query_id][candi_id]
        except KeyError:
            return 0

    def logit_to_score_softmax(logit):
        # Score is the softmax mass at index 1 of the logit vector.
        return scipy.special.softmax(logit)[1]

    grouped: Dict[str,
                  List[QCKOutEntry]] = group_by(out_entries,
                                                lambda x: x.query.query_id)
    for query_id, items in grouped.items():
        # Row keys: (doc_id, passage_idx) signatures seen for this query.
        kdp_list = unique_list([(x.kdp.doc_id, x.kdp.passage_idx)
                                for x in items])
        # Column keys: candidate ids seen for this query.
        candi_id_list = unique_list([x.candidate.id for x in items])

        logit_d = {(x.candidate.id, (x.kdp.doc_id, x.kdp.passage_idx)):
                   x.logits
                   for x in items}
        labels = [get_label(query_id, candi_id) for candi_id in candi_id_list]
        rows = [[" "] + labels, [" "] + candi_id_list]
        for kdp_sig in kdp_list:
            row = [kdp_sig]
            for candi_id in candi_id_list:
                cell_key = (candi_id, kdp_sig)
                # Only the lookup can legitimately miss; test it explicitly
                # so that an error inside the scoring itself is not masked.
                if cell_key in logit_d:
                    score = logit_to_score_softmax(logit_d[cell_key])
                    score_str = "{0:.2f}".format(score)
                else:
                    score_str = "-"
                row.append(score_str)
            rows.append(row)

        print(query_id)
        print_table(rows)
示例#2
0
def main(pred_file_path: str,
         info_file_path: str,
         info_file_path2: str,
         save_name: str,
         input_type: str,
         qrel_path: str,
         ):
    """Count, over positively-labelled pairs, which segment scores highest.

    Predictions are joined with the info loaded from ``info_file_path`` and
    grouped by pair id.  For each group whose qrel label is truthy, the
    class-1 probabilities of all entries are concatenated in ascending
    ``'idx'`` order, truncated to the length looked up in the table built
    from ``info_file_path2``, and the arg-max position is recorded.  The
    resulting ``Counter`` is keyed by ``(max_idx, num_seg)`` and saved via
    ``save_to_pickle``.

    :param pred_file_path: prediction file to join with the info.
    :param info_file_path: info location used for the join.
    :param info_file_path2: info location used to derive document lengths.
    :param save_name: name handed to ``save_to_pickle``.
    :param input_type: key for ``get_format_handler``.
    :param qrel_path: path of the relevance judgements.
    """
    judgement = load_qrels_structured(qrel_path)

    def get_label(key):
        # Missing query or missing document both mean "not relevant".
        query_id, doc_id = key
        try:
            return judgement[query_id][doc_id]
        except KeyError:
            return 0

    f_handler = get_format_handler(input_type)
    info: Dict = load_combine_info_jsons(info_file_path, f_handler.get_mapping(), f_handler.drop_kdp())

    info2: Dict = load_combine_info_jsons(info_file_path2, f_handler.get_mapping(), f_handler.drop_kdp())
    doc_length = get_doc_length_info(info2)
    key_logit = "logits"

    data: List[Dict] = join_prediction_with_info(pred_file_path, info, ["data_id", key_logit])

    grouped = group_by(data, f_handler.get_pair_id)

    cnt = Counter()
    for key, entries in grouped.items():
        if not get_label(key):
            continue
        seg_groups = {}
        for e in entries:
            # Normalise each row separately: without ``axis`` scipy's softmax
            # normalises over the whole 2-D array, which does not yield
            # per-segment class probabilities.
            probs = scipy.special.softmax(e[key_logit], axis=-1)[:, 1]
            seg_groups[e['idx']] = probs

        indices = sorted(seg_groups)
        # Group indices are expected to be exactly 0..len-1.
        assert max(indices) == len(indices) - 1
        all_probs = []
        # Walk the *sorted* indices so that the concatenated positions line
        # up with document order regardless of the order entries arrived in.
        for seg_group_idx in indices:
            all_probs.extend(seg_groups[seg_group_idx])

        num_seg = doc_length[key]
        max_idx = np.argmax(all_probs[:num_seg])

        cnt[(max_idx, num_seg)] += 1

    save_to_pickle(cnt, save_name)
示例#3
0
def save_to_common_path(pred_file_path: str, info_file_path: str,
                        run_name: str, input_type: str, max_entry: int,
                        score_type: str, shuffle_sort: bool):
    """Summarise prediction scores and write them as a TREC ranked list.

    The list is written to ``<output_path>/ranked_list/<run_name>.txt``; the
    directory is passed through ``exist_or_mkdir`` first.

    ``max_entry`` and ``shuffle_sort`` are part of the signature but are not
    read by this implementation.
    """
    handler = get_format_handler(input_type)
    mapping = handler.get_mapping()
    drop_kdp = handler.drop_kdp()
    info: Dict = load_combine_info_jsons(info_file_path, mapping, drop_kdp)
    print("Info has {} entries".format(len(info)))

    ranked_list = summarize_score(info, pred_file_path, handler, score_type)

    ranked_list_dir = os.path.join(output_path, "ranked_list")
    exist_or_mkdir(ranked_list_dir)
    save_path = os.path.join(ranked_list_dir, run_name + ".txt")
    write_trec_ranked_list_entry(ranked_list, save_path)
    print("Saved at : ", save_path)
 def __init__(self, queries, qid_list, probe_config):
     """Store the caller's inputs and load the Robust segment info.

     Path templates for the two score files are resolved under the "rqd"
     output directory; the ``{}`` placeholder is left unformatted here.
     The piece-score cache and bookkeeping counters start out empty.
     """
     # Score-file path templates (long and short segment variants).
     self.long_seg_score_path_format = at_output_dir("rqd", "rqd_{}.score")
     self.short_seg_score_path_format = at_output_dir("rqd", "rqd_sm_{}.score")

     # Segment info, decoded with the "qc" format handler.
     seg_info_path = at_output_dir("robust", "seg_info")
     handler = get_format_handler("qc")
     self.f_handler = handler
     self.info: Dict = load_combine_info_jsons(
         seg_info_path,
         handler.get_mapping(),
         handler.drop_kdp(),
     )
     self.tokenizer = get_tokenizer()

     # Values handed in by the caller.
     self.queries = queries
     self.qid_list: List[str] = qid_list
     self.probe_config = probe_config

     # Mutable state, filled in later.
     self.doc_piece_score_d: Dict[Tuple[str, str], List[ScoredPieceFromPair]] = {}
     self.prepared_qids = set()
     self.not_found_cnt = 0
示例#5
0
def load_scores(info_file_path, prediction_file_path):
    """Join predictions with "qc" info and return them grouped by pair id.

    :param info_file_path: location passed to ``load_combine_info_jsons``.
    :param prediction_file_path: prediction file to join with the info.
    :return: the mapping produced by ``group_by`` with the handler's
        ``get_pair_id`` as the key function.
    """
    handler = get_format_handler("qc")

    tprint("Loading json info")
    info: Dict = load_combine_info_jsons(
        info_file_path,
        handler.get_mapping(),
        handler.drop_kdp(),
    )

    tprint("Reading predictions...")
    joined: List[Dict] = join_prediction_with_info(
        prediction_file_path,
        info,
        ["data_id", "logits"],
    )

    groups: Dict[Tuple[str, str], List[Dict]] = group_by(joined,
                                                         handler.get_pair_id)
    print("number of groups:", len(groups))
    return groups
示例#6
0
def main(info_path, input_type, label_dict_path, save_path):
    """Turn a pickled correctness dictionary into TREC relevance judgements.

    One judgement is emitted per info entry, in the iteration order of the
    info mapping.  A pair that is absent from the pickled dictionary is
    written with relevance 0.
    """
    handler = get_format_handler(input_type)
    info: Dict[str, Dict] = load_combine_info_jsons(
        info_path, handler.get_mapping(), handler.drop_kdp())
    label_dict: Dict[Tuple[str, str], bool] = load_pickle_from(label_dict_path)

    def to_judgement(entry):
        # The pair id doubles as the lookup key and as (query, candidate).
        pair_id = handler.get_pair_id(entry)
        query_id, candidate_id = pair_id
        is_correct = label_dict[pair_id] if pair_id in label_dict else False
        return TrecRelevanceJudgementEntry(query_id, candidate_id,
                                           int(is_correct))

    judgements = [to_judgement(entry) for entry in info.values()]
    write_trec_relevance_judgement(judgements, save_path)
示例#7
0
def save_to_common_path(pred_file_path: str, info_file_path: str,
                        run_name: str, input_type: str, max_entry: int,
                        combine_strategy: str, score_type: str,
                        shuffle_sort: bool):
    """Score one prediction file and save the result as a TREC ranked list.

    Output goes to ``<output_path>/ranked_list/<run_name>.txt``.  Progress
    messages are emitted through ``tprint``.
    """
    tprint("Reading info...")
    handler = get_format_handler(input_type)
    mapping = handler.get_mapping()
    drop_kdp = handler.drop_kdp()
    info: Dict = load_combine_info_jsons(info_file_path, mapping, drop_kdp)
    tprint("Info has {} entries".format(len(info)))

    scores = get_score_d(pred_file_path, info, handler, combine_strategy,
                         score_type)
    ranked_list = scrore_d_to_trec_style_predictions(scores, run_name,
                                                     max_entry, shuffle_sort)

    ranked_list_dir = os.path.join(output_path, "ranked_list")
    exist_or_mkdir(ranked_list_dir)
    save_path = os.path.join(ranked_list_dir, run_name + ".txt")
    write_trec_ranked_list_entry(ranked_list, save_path)
    tprint("Saved at : ", save_path)
def save_over_multiple_files(pred_file_list: List[str], info_file_path: str,
                             run_name: str, input_type: str, max_entry: int,
                             combine_strategy: str, score_type: str):
    """Merge scores from several prediction files into one TREC ranked list.

    Score dictionaries are merged with ``dict.update`` in list order, so a
    key that appears in more than one file keeps the value from the last
    file that contains it.  Output goes to
    ``<output_path>/ranked_list/<run_name>.txt``.
    """
    handler = get_format_handler(input_type)
    mapping = handler.get_mapping()
    drop_kdp = handler.drop_kdp()
    info: Dict = load_combine_info_jsons(info_file_path, mapping, drop_kdp)
    print("Info has {} entries".format(len(info)))

    merged_scores = {}
    for pred_file_path in pred_file_list:
        merged_scores.update(
            get_score_d(pred_file_path, info, handler, combine_strategy,
                        score_type))

    ranked_list = scrore_d_to_trec_style_predictions(merged_scores, run_name,
                                                     max_entry)

    ranked_list_dir = os.path.join(output_path, "ranked_list")
    exist_or_mkdir(ranked_list_dir)
    save_path = os.path.join(ranked_list_dir, run_name + ".txt")
    write_trec_ranked_list_entry(ranked_list, save_path)
    print("Saved at : ", save_path)
示例#9
0
def main():
    """Visualise score differences between two windowed runs as HTML.

    For the hard-coded job index, two score files ("windowed_{idx}.score"
    and "windowed_small_{idx}.score") are read in lockstep.  Piece scores
    from both are matched on their end offset; for each match the tokens in
    ``doc[st:st2]`` are kept together with ``score - score2``.  The result
    is written to "windowed.html": one block per query/document pair, with
    both probability lists followed by the tokens shaded by the magnitude
    of the difference and coloured by its sign.
    """
    n_factor = 16
    step_size = 16
    max_seq_length = 128
    max_seq_length2 = 128 - 16
    batch_size = 8
    info_file_path = at_output_dir("robust", "seg_info")
    queries = load_robust_04_query("desc")
    qid_list = get_robust_qid_list()

    f_handler = get_format_handler("qc")
    info: Dict = load_combine_info_jsons(info_file_path,
                                         f_handler.get_mapping(),
                                         f_handler.drop_kdp())
    print(len(info))
    tokenizer = get_tokenizer()

    for job_idx in [1]:
        qid = qid_list[job_idx]
        query = queries[str(qid)]
        q_term_length = len(tokenizer.tokenize(query))
        data_path1 = os.path.join(output_path, "robust",
                                  "windowed_{}.score".format(job_idx))
        data_path2 = os.path.join(output_path, "robust",
                                  "windowed_small_{}.score".format(job_idx))
        data1 = OutputViewer(data_path1, n_factor, batch_size)
        data2 = OutputViewer(data_path2, n_factor, batch_size)
        # Room left for document tokens once the query and 3 extra
        # positions are subtracted from the sequence length.
        # NOTE(review): the 3 is presumably for special tokens - confirm.
        segment_len = max_seq_length - 3 - q_term_length
        segment_len2 = max_seq_length2 - 3 - q_term_length

        outputs = []
        for d1, d2 in zip(data1, data2):
            # for each query, doc pairs
            cur_info1 = info[d1['data_id']]
            cur_info2 = info[d2['data_id']]
            query_doc_id1 = f_handler.get_pair_id(cur_info1)
            query_doc_id2 = f_handler.get_pair_id(cur_info2)

            # Both files must describe the same pair at the same position.
            assert query_doc_id1 == query_doc_id2

            doc = d1['doc']
            probs = get_probs(d1['logits'])
            probs2 = get_probs(d2['logits'])
            n_pred_true = np.count_nonzero(np.less(0.5, probs))
            print(n_pred_true, len(probs))

            seg_scores: List[Tuple[int, int, float]] = get_piece_scores(
                n_factor, probs, segment_len, step_size)
            seg_scores2: List[Tuple[int, int, float]] = get_piece_scores(
                n_factor, probs2, segment_len2, step_size)
            ss_list = []
            for st, ed, score in seg_scores:
                # Only the search may raise StopIteration; keep the handler
                # scoped to it so nothing else is silently skipped.
                try:
                    st2, ed2, score2 = find_where(lambda x: x[1] == ed,
                                                  seg_scores2)
                except StopIteration:
                    # No piece in the second run ends at this offset.
                    continue
                assert ed == ed2
                assert st < st2
                tokens = tokenizer.convert_ids_to_tokens(doc[st:st2])
                diff = score - score2
                ss = ScoredPiece(st, st2, diff, tokens)
                ss_list.append(ss)
            outputs.append((probs, probs2, query_doc_id1, ss_list))

        html = HtmlVisualizer("windowed.html")

        for probs, probs2, query_doc_id, ss_list in outputs:
            html.write_paragraph(str(query_doc_id))
            html.write_paragraph("Query: " + query)

            ss_list.sort(key=lambda ss: ss.st)
            prev_end = None
            cells = []
            prob_str1 = lmap(two_digit_float, probs)
            # NOTE(review): the leading "8.88" looks like a placeholder that
            # shifts the first list by one slot relative to the second -
            # confirm the intent.
            prob_str1 = ["8.88"] + prob_str1
            prob_str2 = lmap(two_digit_float, probs2)
            html.write_paragraph(" ".join(prob_str1))
            html.write_paragraph(" ".join(prob_str2))

            for ss in ss_list:
                if prev_end is not None:
                    # Consecutive pieces are expected to be contiguous.
                    assert prev_end == ss.st
                else:
                    print(ss.st)

                score = abs(int(100 * ss.score))
                # Colour by the sign of the raw difference.  Testing the
                # absolute value instead would never select "R" for a
                # negative difference.
                color = "B" if ss.score > 0 else "R"
                cells.extend(
                    [Cell(t, score, target_color=color) for t in ss.tokens])
                prev_end = ss.ed

            html.multirow_print(cells)
示例#10
0
def prec_recall(pred_file_path: str, info_file_path: str, input_type: str,
                score_type: str, qrel_path: str):
    """Print accuracy/precision/recall statistics per (query, kdp) group.

    Grouped predictions are taken from the "ck_based_analysis" cache when
    ``load_cache`` returns something other than ``None``; otherwise they are
    rebuilt from ``pred_file_path`` and ``info_file_path``.  Within each
    group an item is predicted positive when its score exceeds 0.5, and its
    gold label is the truthiness of the qrel judgement for the item's
    document id.  Groups whose query has no judgements are skipped.

    :param pred_file_path: prediction file, used only on a cache miss.
    :param info_file_path: info location, used only on a cache miss.
    :param input_type: key for ``get_format_handler``, used only on a cache
        miss.
    :param score_type: forwarded to ``get_score_from_logit``.
    :param qrel_path: path of the relevance judgements.
    """
    judgments_raw: Dict[str, List[Tuple[str,
                                        int]]] = load_qrels_flat(qrel_path)
    judgments = dict_value_map(dict, judgments_raw)

    grouped = load_cache("ck_based_analysis")
    key_logit = "logits"

    # NOTE(review): a rebuilt grouping is not written back to the cache in
    # this function - confirm whether that is intended.
    if grouped is None:
        f_handler = get_format_handler(input_type)
        info: Dict = load_combine_info_jsons(info_file_path,
                                             f_handler.get_mapping(),
                                             f_handler.drop_kdp())
        data: List[Dict] = join_prediction_with_info(pred_file_path, info,
                                                     ["data_id", key_logit])
        grouped = group_by(data, get_qk_pair_id)

    def get_score(entry):
        return get_score_from_logit(score_type, entry[key_logit])

    def get_label(query_id, candidate_id):
        # Unjudged documents count as non-relevant.  The caller has already
        # verified that the query itself is present.
        judge_dict = judgments[query_id]
        if candidate_id in judge_dict:
            return judge_dict[candidate_id]
        return 0

    # Single source of truth for metric columns: keeps header and rows
    # aligned.
    metric_names = [
        "accuracy", "precision", "recall", "f1", "tp", "fp", "tn", "fn"
    ]
    rows = [["query_id", "kdp_id"] + metric_names]
    for pair_id, items in grouped.items():
        query_id, kdp_id = pair_id
        if query_id not in judgments:
            continue

        labels = []
        predictions = []
        for item in items:
            score = get_score(item)
            # Candidate ids are mapped to a document id via get_doc_id;
            # labels are looked up per document.
            doc_id = get_doc_id(item['candidate'].id)
            labels.append(bool(get_label(query_id, doc_id)))
            predictions.append(score > 0.5)

        scores = get_acc_prec_recall(predictions, labels)
        rows.append([query_id, kdp_id] +
                    [scores[name] for name in metric_names])
    print_table(rows)