Example #1
File: eval.py  Project: clover3/Chair
def collect_failure(split, scorer: Callable[[Passage, List[Passage]],
                                            List[NamedNumber]],
                    condition: EvalCondition):
    problems, candidate_pool_d = prepare_eval_data(split)

    problems = problems[:100]
    payload: List[Passage] = get_eval_payload_from_dp(problems)
    for query, problem in zip(payload, problems):
        p = problem
        candidate_ids: List[ArguDataID] = retrieve_candidate(
            query, split, condition)
        candidate: List[Passage] = list(
            [candidate_pool_d[x] for x in candidate_ids])
        scores: List[NamedNumber] = scorer(query, candidate)
        best_idx = get_max_idx(scores)
        pred_item: Passage = candidate[best_idx]
        gold_id = p.text2.id
        pred_id = pred_item.id
        correct = gold_id == pred_id
        content_equal = (p.text2.text == pred_item.text)
        correct = correct or content_equal
        gold_idx_l = list(
            [idx for idx, c in enumerate(candidate) if c.id == gold_id])
        gold_idx = gold_idx_l[0] if gold_idx_l else None
        gold_score = scores[gold_idx] if gold_idx_l else None
        if not correct:
            e = p.text1.text, p.text2.text, pred_item.text
            yield e
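
Every snippet on this page calls get_max_idx, which is defined elsewhere in the clover3/Chair project and not shown here. A minimal sketch, assuming it simply returns the index of the largest element (ties going to the first occurrence), which is consistent with how these examples use it:

def get_max_idx(scores):
    # Assumed behavior: index of the maximum value; on ties the first
    # occurrence wins.
    return max(range(len(scores)), key=lambda i: scores[i])
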
Example #2
def lcs(text_a, text_b, debug=False):
    if not text_a or not text_b:
        return 0, []
    a_len = len(text_a)
    b_len = len(text_b)

    lcs_val = {}
    # lcs_val[i, j] = length of the LCS of a[:i+1] and b[:j+1]

    # lcs_val[i, j] = max(
    #   1) lcs_val[i-1, j-1] + 1  if a[i] == b[j]
    #   2) lcs_val[i-1, j]
    #   3) lcs_val[i, j-1]

    # lcs_type[i, j] records which option (1, 2, or 3) was taken, for the backtrace below
    lcs_type = {}

    def get_arr_val(i, j):
        if i < 0 or j < 0:
            return 0
        else:
            return lcs_val[(i, j)]

    for i in range(a_len):
        for j in range(b_len):

            if text_a[i] == text_b[j]:
                opt1 = get_arr_val(i - 1, j - 1) + 1
            else:
                opt1 = -1

            opt2 = get_arr_val(i - 1, j)
            opt3 = get_arr_val(i, j - 1)
            opt_list = [opt1, opt2, opt3]
            idx = get_max_idx(opt_list)
            opt = idx + 1
            lcs_type[i, j] = opt
            lcs_val[i, j] = opt_list[idx]

    i = a_len - 1
    j = b_len - 1
    match_log = []
    while i >= 0 and j >= 0:
        opt = lcs_type[i, j]
        if opt == 1:
            match_log.append((i, j))
            i, j = i - 1, j - 1
        elif opt == 2:
            i, j = i - 1, j
        elif opt == 3:
            i, j = i, j - 1
        else:
            assert False

    return lcs_val[a_len - 1, b_len - 1], match_log[::-1]
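
A short usage sketch for the lcs function above. The match pairs shown assume the get_max_idx sketch from Example #1 (ties resolved toward the first option):

length, matches = lcs("ABC", "AC")
print(length)   # 2
print(matches)  # [(0, 0), (2, 1)] -- (index in text_a, index in text_b) per matched character
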
Example #3
File: eval.py  Project: clover3/Chair
def run_eval(split, scorer: Callable[[Passage, List[Passage]],
                                     List[NamedNumber]],
             condition: EvalCondition):
    problems, candidate_d = prepare_eval_data(split)
    debug = False

    problems = problems[:100]
    payload: List[Passage] = get_eval_payload_from_dp(problems)
    correctness = []
    fail_type_count = Counter()
    for query, problem in zip(payload, problems):
        p = problem
        candidate_ids: List[ArguDataID] = retrieve_candidate(
            query, split, condition)
        candidate = list([candidate_d[x] for x in candidate_ids])
        scores: List[NamedNumber] = scorer(query, candidate)
        best_idx = get_max_idx(scores)
        pred_item: Passage = candidate[best_idx]
        gold_id = p.text2.id
        pred_id = pred_item.id
        correct = gold_id == pred_id
        content_equal = (p.text2.text == pred_item.text)
        correct = correct or content_equal
        gold_idx_l = list(
            [idx for idx, c in enumerate(candidate) if c.id == gold_id])
        gold_idx = gold_idx_l[0] if gold_idx_l else None
        gold_score = scores[gold_idx] if gold_idx_l else None

        if not correct and debug:
            print("-------------------", correct, content_equal)
            print("QUERY:", p.text1.id)
            print(p.text1.text)
            print("GOLD:", p.text2.id)
            print(p.text2.text)
            print("PRED:", pred_item.id)
            # print(scores[best_idx].name)
            print(pred_item.text)
            print("RATIONALE: ", scores[best_idx].__float__(),
                  scores[best_idx].name)
            if gold_score is not None:
                print("GOLD RATIONALE:", gold_score.__float__(),
                      gold_score.name)

            t = failure_type(p.text2.id.id, pred_item.id.id)
            fail_type_count[t] += 1

        correctness.append(correct)
    avg_p_at_1 = average(correctness)
    #print(fail_type_count)
    return avg_p_at_1
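
Both run_eval and collect_failure take a scorer of type Callable[[Passage, List[Passage]], List[NamedNumber]]. A hypothetical token-overlap scorer is sketched below; Passage is assumed to expose a .text attribute, and NamedNumber is replaced by a small stand-in that mimics the interface these examples rely on (a float value plus a .name explanation):

class NamedNumber(float):
    # Stand-in for the project's NamedNumber: a float that also carries a
    # short explanation in .name (interface assumed from run_eval's usage).
    def __new__(cls, value, name=""):
        obj = super().__new__(cls, value)
        obj.name = name
        return obj


def overlap_scorer(query, candidates):
    # Hypothetical scorer: count lowercase tokens shared between the query
    # passage and each candidate passage.
    q_tokens = set(query.text.lower().split())
    scores = []
    for c in candidates:
        n = len(q_tokens & set(c.text.lower().split()))
        scores.append(NamedNumber(n, "shared tokens: %d" % n))
    return scores

Such a scorer would be passed as, e.g., run_eval(split, overlap_scorer, condition).
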
Example #4
    def get_ranks(self, qids, scorer: SegScorer):
        missing_cnt = 0
        missing_doc_qid = []

        rel_rank_list = []

        for qid in TEL(qids):
            if qid not in self.resource.get_doc_for_query_d():
                # assert not self.resource.query_in_qrel(qid)
                continue

            query_text = self.resource.get_query_text(qid)
            stemmed_tokens_d = self.resource.get_stemmed_tokens_d(qid)
            entries = []
            for doc_id in self.resource.get_doc_for_query_d()[qid]:
                label = self.resource.get_label(qid, doc_id)
                try:
                    stemmed_title_tokens, stemmed_body_tokens_list = stemmed_tokens_d[
                        doc_id]
                    scores: List[float] = scorer.get_scores(
                        query_text,
                        stemmed_title_tokens,
                        stemmed_body_tokens_list,
                    )
                    max_score = max(scores) if scores else -999
                    if scores:
                        max_idx = get_max_idx(scores)
                        print(max_idx, len(scores))
                    entries.append((doc_id, max_score, label))
                except KeyError:
                    missing_cnt += 1
                    missing_doc_qid.append(qid)

            entries.sort(key=lambda x: x[1], reverse=True)
            if not entries:
                continue

            rel_rank = -1
            for rank, (doc_id, score, label) in enumerate(entries):
                if label:
                    rel_rank = rank
                    break
            rel_rank_list.append(rel_rank)
        return rel_rank_list
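
get_ranks relies on a single SegScorer method: get_scores(query_text, stemmed_title_tokens, stemmed_body_tokens_list), returning one float per body segment. A hypothetical stand-in, not the project's actual scorer:

class CountOverlapSegScorer:
    # Hypothetical SegScorer: score each body segment by how many of the
    # whitespace-tokenized query terms it contains (title tokens ignored).
    def get_scores(self, query_text, stemmed_title_tokens,
                   stemmed_body_tokens_list):
        q_terms = set(query_text.lower().split())
        return [float(len(q_terms & set(seg)))
                for seg in stemmed_body_tokens_list]
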
Example #5
    def encode(
            self, query_text, stemmed_title_tokens: List[str],
            stemmed_body_tokens_list: List[List[str]],
            bert_title_tokens: List[str],
            bert_body_tokens_list: List[List[str]]) -> List[Tuple[List, List]]:

        # Title and body sentences are trimmed to 64 * 5 chars
        # score each sentence based on bm25_module
        stemmed_query_tokens = self.tokenize_stem(query_text)
        q_tf = Counter(stemmed_query_tokens)
        assert len(stemmed_body_tokens_list) == len(bert_body_tokens_list)

        stemmed_body_tokens_list = regroup_sent_list(stemmed_body_tokens_list,
                                                     4)
        bert_body_tokens_list = regroup_sent_list(bert_body_tokens_list, 4)

        def get_score(sent_idx):
            if self.include_title:
                tokens = stemmed_title_tokens + stemmed_body_tokens_list[
                    sent_idx]
            else:
                tokens = stemmed_body_tokens_list[sent_idx]

            doc_tf = Counter(tokens)
            return self.bm25_module.score_inner(q_tf, doc_tf)

        bert_query_tokens = self.bert_tokenize(query_text)
        if stemmed_body_tokens_list:
            seg_scores = lmap(get_score, range(len(stemmed_body_tokens_list)))
            max_idx = get_max_idx(seg_scores)
            content_len = self.max_seq_length - 3 - len(bert_query_tokens)
            second_tokens = bert_body_tokens_list[max_idx][:content_len]
        else:
            second_tokens = []
        out_tokens = ["[CLS]"] + bert_query_tokens + [
            "[SEP]"
        ] + second_tokens + ["[SEP]"]
        segment_ids = [0] * (len(bert_query_tokens) +
                             2) + [1] * (len(second_tokens) + 1)
        entry = out_tokens, segment_ids
        return [entry]
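
regroup_sent_list is not shown on this page. A minimal sketch, assuming it merges every k consecutive sentence token lists into one list (which keeps the stemmed and BERT views aligned, since both are regrouped with the same k):

def regroup_sent_list(sent_tokens_list, k):
    # Assumed behavior: concatenate every k consecutive sentences into a
    # single token list.
    grouped = []
    for i in range(0, len(sent_tokens_list), k):
        merged = []
        for sent in sent_tokens_list[i:i + k]:
            merged.extend(sent)
        grouped.append(merged)
    return grouped
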
Example #6
def main():
    rlp = "C:\\work\\Code\\Chair\\output\\clue_counter_arg\\ranked_list.txt"
    html_dir = "C:\\work\\Code\\Chair\\output\\clue_counter_arg\\docs"

    grouped: Dict[str, List[Tuple[str, str]]] = load_all_docs_cleaned(
        rlp, html_dir)
    tids_score_dict = get_f5_tids_score_d_from_svm()

    def get_score(text):
        if text in tids_score_dict:
            return tids_score_dict[text]
        else:
            return -10000

    class AnalyezedDoc(NamedTuple):
        doc_id: str
        text: str
        score: float
        max_score_sent: str

    for query, entries in grouped.items():
        ad_list = []
        for doc_id, text in entries:
            all_text_list = [text] + sent_tokenize(text)
            scores = lmap(get_score, all_text_list)
            max_idx_ = get_max_idx(scores)
            max_score = scores[max_idx_]
            ad = AnalyezedDoc(doc_id, text, max_score, all_text_list[max_idx_])
            ad_list.append(ad)

        ad_list.sort(key=lambda x: x.score, reverse=True)
        print("QID: ", query)
        for ad in ad_list[:5]:
            rows = [['doc_id', ad.doc_id], ['score', ad.score],
                    ['max_sent', ad.max_score_sent], ['fulltext', ad.text]]
            print("-----")
            print_table(rows)
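
Examples #5 and #6 also use lmap, another project helper; it is assumed to be an eager list(map(...)):

def lmap(fn, items):
    # Assumed behavior: apply fn to every item and return the results as a list.
    return list(map(fn, items))
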