Example #1
def main():
    todo: List[Tuple[QueryID, MSMarcoDoc]] = get_todo()
    msmarco_passage_qrel_path = at_data_dir("msmarco", "qrels.train.tsv")
    passage_qrels: QRelsDict = load_qrels_structured(msmarco_passage_qrel_path)

    try:
        passage_dict = load_from_pickle("msmarco_passage_doc_analyze_passage_dict")
    except FileNotFoundError:
        passage_dict = load_passage_dict(todo, passage_qrels)
    doc_queries = dict(load_train_queries())

    itr: Iterable[Tuple[str, MSMarcoDoc, JoinedPassage]] = join_doc_passage(todo, passage_qrels, passage_dict)
    ##
    for qid, doc, passage in itr:
        query_text = doc_queries[QueryID(qid)]
        print('query', qid, query_text)
        prev = doc.body[:passage.loc]
        passage_text = passage.text
        tail = doc.body[passage.loc + len(passage_text):]
        print("-----")
        print(prev)
        print(">>>")
        print(passage_text)
        print("<<<")
        print(tail)
        print("-----")
Example #2
def main():
    qrel_path = "/home/youngwookim/Downloads/rob04-desc/qrels.rob04.txt"
    judgement = load_qrels_structured(qrel_path)

    def is_correct(query: QCKQuery, candidate: QCKCandidate):
        qid = query.query_id
        doc_part_id = candidate.id
        doc_id = "_".join(doc_part_id.split("_")[:-1])
        if qid not in judgement:
            return 0
        d = judgement[qid]
        if doc_id in d:
            return d[doc_id]
        else:
            return 0

    qk_candidate: List[QKUnit] = load_from_pickle(
        "robust_on_clueweb_qk_candidate")
    candidate_dict: \
        Dict[str, List[QCKCandidateI]] = load_candidate_all_passage_from_qrel(256)
    generator = QCKInstanceGenerator(candidate_dict, is_correct)
    num_jobs = 250

    def worker_factory(out_dir):
        worker = QCKWorker(qk_candidate, generator, out_dir)
        return worker

    ##
    job_name = "robust_qck_6"
    runner = JobRunner(job_man_dir, num_jobs, job_name, worker_factory)
    runner.start()
Example #3
    def __init__(self, query_type="desc", neg_k=1000):
        qrel_path = "/home/youngwookim/Downloads/rob04-desc/qrels.rob04.txt"
        self.judgement = load_qrels_structured(qrel_path)
        self.queries = load_robust_04_query(query_type)
        self.tokenizer = get_tokenizer()
        self.galago_rank = load_bm25_best()
        self.neg_k = neg_k
Example #4
def main():
    qrel_path = "/home/youngwookim/Downloads/rob04-desc/qrels.rob04.txt"
    judgement = load_qrels_structured(qrel_path)

    def is_correct(query: QCKQuery, candidate: QCKCandidate):
        qid = query.query_id
        doc_id = candidate.id
        if qid not in judgement:
            return 0
        d = judgement[qid]
        label = 1 if doc_id in d and d[doc_id] > 0 else 0
        return label

    qk_candidate: List[QKUnit] = load_from_pickle(
        "robust_on_clueweb_qk_candidate_filtered")

    candidate_dict = load_cache("candidate_for_robust_qck_7")
    if candidate_dict is None:
        candidate_dict: \
            Dict[str, List[QCKCandidateI]] = get_candidate_all_passage_w_samping()
        save_to_pickle(candidate_dict, "candidate_for_robust_qck_7")

    generator = QCKInstanceGenerator(candidate_dict, is_correct)
    num_jobs = 250

    def worker_factory(out_dir):
        worker = QCKWorker(qk_candidate, generator, out_dir)
        return worker

    ##
    job_name = "robust_qck_10"
    runner = JobRunner(job_man_dir, num_jobs, job_name, worker_factory)
    runner.start()
Example #5
    def __init__(self, encoder, max_seq_length):
        self.data = load_robust_tokens_for_train()
        assert len(self.data) == 174787
        qrel_path = "/home/youngwookim/Downloads/rob04-desc/qrels.rob04.txt"
        self.judgement = load_qrels_structured(qrel_path)
        self.max_seq_length = max_seq_length
        self.queries = load_robust04_title_query()
        self.encoder = encoder
        self.tokenizer = get_tokenizer()
Example #6
    def __init__(self, encoder, max_seq_length, query_type="title"):
        self.data = self.load_tokens()
        qrel_path = "/home/youngwookim/Downloads/rob04-desc/qrels.rob04.txt"
        self.judgement = load_qrels_structured(qrel_path)
        self.max_seq_length = max_seq_length
        self.queries = load_robust_04_query(query_type)
        self.encoder = encoder
        self.tokenizer = get_tokenizer()
        self.galago_rank = load_bm25_best()
Example #7
    def __init__(self, doc_max_length, query_type="title", neg_k=1000, pos_only=True):
        self.data = self.load_tokens()
        qrel_path = "/home/youngwookim/Downloads/rob04-desc/qrels.rob04.txt"
        self.judgement = load_qrels_structured(qrel_path)
        self.doc_max_length = doc_max_length
        self.queries = load_robust_04_query(query_type)
        self.tokenizer = get_tokenizer()
        self.galago_rank = load_bm25_best()
        self.neg_k = neg_k
        self.pos_only = pos_only
Example #8
def main(config):
    info_dir = config['info_path']
    prediction_file = config['pred_path']

    f_handler = get_format_handler("qck")
    info = load_combine_info_jsons(info_dir, f_handler.get_mapping(),
                                   f_handler.drop_kdp())
    data: List[Dict] = join_prediction_with_info(prediction_file, info,
                                                 ["data_id", "logits"])
    out_entries: List[QCKOutEntry] = lmap(QCKOutEntry.from_dict, data)
    qrel: Dict[str, Dict[str,
                         int]] = load_qrels_structured(config['qrel_path'])

    def get_label(query_id, candi_id):
        if candi_id in qrel[query_id]:
            return qrel[query_id][candi_id]
        else:
            return 0

    def logit_to_score_softmax(logit):
        return scipy.special.softmax(logit)[1]

    grouped: Dict[str,
                  List[QCKOutEntry]] = group_by(out_entries,
                                                lambda x: x.query.query_id)
    for query_id, items in grouped.items():
        raw_kdp_list = [(x.kdp.doc_id, x.kdp.passage_idx) for x in items]
        kdp_list = unique_list(raw_kdp_list)

        raw_candi_id_list = [x.candidate.id for x in items]
        candi_id_list = unique_list(raw_candi_id_list)

        logit_d = {(x.candidate.id, (x.kdp.doc_id, x.kdp.passage_idx)):
                   x.logits
                   for x in items}
        labels = [get_label(query_id, candi_id) for candi_id in candi_id_list]
        head_row0 = [" "] + labels
        head_row1 = [" "] + candi_id_list
        rows = [head_row0, head_row1]
        for kdp_sig in kdp_list:
            row = [kdp_sig]
            for candi_id in candi_id_list:
                try:
                    score = logit_to_score_softmax(logit_d[candi_id, kdp_sig])
                    score_str = "{0:.2f}".format(score)
                except KeyError:
                    score_str = "-"
                row.append(score_str)
            rows.append(row)

        print(query_id)
        print_table(rows)
Example #9
    def __init__(self, encoder, max_seq_length, query_type,
                 target_selection_fn: Callable[[str, str, List], List[int]]):
        self.data = self.load_tokens()
        qrel_path = "/home/youngwookim/Downloads/rob04-desc/qrels.rob04.txt"
        self.judgement = load_qrels_structured(qrel_path)
        self.max_seq_length = max_seq_length
        self.queries = load_robust_04_query(query_type)
        self.encoder = encoder
        self.tokenizer = get_tokenizer()
        self.galago_rank = load_bm25_best()

        self.target_selection_fn: Callable[[str, str, List],
                                           List[int]] = target_selection_fn
Example #10
def main():
    rlg_proposed_tfidf = load_ranked_list_grouped(sys.argv[1])
    rlg_proposed_bm25 = load_ranked_list_grouped(sys.argv[2])
    rlg_bert_tfidf = load_ranked_list_grouped(sys.argv[3])
    qrel: QRelsDict = load_qrels_structured(sys.argv[4])

    # TODO
    #  Q1 ) Is the set of document different?
    #  Q2 ) Say 2 (BM25) is better than 1 (tf-idf),
    #    1. X=0, in 2 not in 1
    #    2. X=1, in 2 not in 1
    #    3. X=0, in 1 not in 2   -> FP prediction that BERT(baseline) misses
    #    4. X=1, in 1 not in 2
    cnt = 0
    for q in rlg_proposed_tfidf:
        entries1 = rlg_proposed_tfidf[q]
        entries2 = rlg_proposed_bm25[q]
        entries3 = rlg_bert_tfidf[q]
        e3_d = {e.doc_id: e for e in entries3}

        def get_doc_set(entries):
            return set(map(TrecRankedListEntry.get_doc_id, entries))

        docs1 = get_doc_set(entries1)
        docs2 = get_doc_set(entries2)

        d = qrel[q]
        rows = [[q]]
        rows.append([
            'doc_id', 'label', 'in_bm25', '1_rank', '1_score', '3_rank',
            '3_score'
        ])
        for e in entries1:
            label = d[e.doc_id] if e.doc_id in d else 0
            #if e.doc_id not in docs2:
            if True:
                # Case 3
                predict_binary = e.rank < 20
                try:
                    e3 = e3_d[e.doc_id]
                    row = [
                        e.doc_id, label, e.doc_id in docs2, e.rank, e.score,
                        e3.rank, e3.score
                    ]
                    rows.append(row)
                except KeyError:
                    assert cnt == 0
                    cnt += 1

        if len(rows) > 2:
            print_table(rows)
Example #11
def main(pred_file_path: str,
         info_file_path: str,
         info_file_path2: str,
         save_name: str,
         input_type: str,
         qrel_path: str,
         ):

    judgement = load_qrels_structured(qrel_path)
    def get_label(key):
        query_id, doc_id = key
        try:
            return judgement[query_id][doc_id]
        except KeyError:
            return 0

    f_handler = get_format_handler(input_type)
    info: Dict = load_combine_info_jsons(info_file_path, f_handler.get_mapping(), f_handler.drop_kdp())

    info2: Dict = load_combine_info_jsons(info_file_path2, f_handler.get_mapping(), f_handler.drop_kdp())
    doc_length = get_doc_length_info(info2)
    key_logit = "logits"

    data: List[Dict] = join_prediction_with_info(pred_file_path, info, ["data_id", key_logit])

    grouped = group_by(data, f_handler.get_pair_id)

    cnt = Counter()
    for key, entries in grouped.items():
        if not get_label(key):
            continue
        seg_groups = {}
        for e in entries:
            probs = scipy.special.softmax(e['logits'])[:, 1]
            seg_groups[e['idx']] = probs

        indices = list(seg_groups.keys())
        indices.sort()
        assert max(indices) == len(indices) - 1
        all_probs = []
        for seg_group_idx in seg_groups.keys():
            all_probs.extend(seg_groups[seg_group_idx])

        num_seg = doc_length[key]
        max_idx = np.argmax(all_probs[:num_seg])


        cnt[(max_idx, num_seg)] += 1

    save_to_pickle(cnt, save_name)
Example #12
    def __init__(self,
                 encoder,
                 max_seq_length,
                 score_d,
                 query_type="title",
                 neg_k=1000):
        self.data = self.load_tokens()
        qrel_path = "/home/youngwookim/Downloads/rob04-desc/qrels.rob04.txt"
        self.score_d: Dict[str, List[float]] = score_d
        self.judgement = load_qrels_structured(qrel_path)
        self.max_seq_length = max_seq_length
        self.queries = load_robust_04_query(query_type)
        self.encoder = encoder
        self.tokenizer = get_tokenizer()
        self.galago_rank = load_bm25_best()
        self.neg_k = neg_k
        self.n_seg_per_doc = 4
Example #13
def load_candidate_all_passage_from_qrel(
        max_seq_length,
        max_passage_per_doc=10) -> Dict[str, List[QCKCandidateWToken]]:
    qrel_path = os.path.join(data_path, "robust", "qrels.rob04.txt")
    judgement: Dict[str, Dict] = load_qrels_structured(qrel_path)

    candidate_doc_ids = {}
    for query_id in judgement.keys():
        judge_entries = judgement[query_id]
        doc_ids = list(judge_entries.keys())
        candidate_doc_ids[query_id] = doc_ids

    token_data = load_robust_tokens_for_train()

    return load_candidate_all_passage_inner(candidate_doc_ids, token_data,
                                            max_seq_length,
                                            max_passage_per_doc, 9999)
Example #14
    def __init__(self,
                 encoder,
                 max_seq_length_per_inst,
                 num_doc_per_inst,
                 num_seg_per_inst,
                 query_type="title",
                 neg_k=1000):
        self.data = self.load_tokens()
        qrel_path = "/home/youngwookim/Downloads/rob04-desc/qrels.rob04.txt"
        self.judgement = load_qrels_structured(qrel_path)
        self.max_seq_length = max_seq_length_per_inst
        self.queries = load_robust_04_query(query_type)
        self.num_doc_per_inst = num_doc_per_inst
        self.num_seg_per_inst = num_seg_per_inst

        self.all_segment_encoder = encoder
        self.tokenizer = get_tokenizer()
        self.galago_rank = load_bm25_best()
        self.neg_k = neg_k
Example #15
def main():
    l1: Dict[str,
             List[TrecRankedListEntry]] = load_ranked_list_grouped(sys.argv[1])
    qrel: Dict[str, Dict[str, int]] = load_qrels_structured(sys.argv[2])

    threshold_list = []
    ptr = 0.0
    for i in range(10):
        v = ptr + i / 10
        threshold_list.append(v)
    res = []
    for t in threshold_list:
        prec_list = []
        recall_list = []
        for query, ranked_list in l1.items():
            gold_dict = qrel[query] if query in qrel else {}
            gold_docs = []
            for doc_id, label in gold_dict.items():
                if label:
                    gold_docs.append(doc_id)

            pred_list = []
            for e in ranked_list:
                if e.score > t:
                    pred_list.append(e.doc_id)

            common = set(gold_docs).intersection(set(pred_list))
            tp = len(common)
            prec = tp / len(pred_list) if pred_list else 1
            recall = tp / len(gold_docs) if gold_docs else 1
            prec_list.append(prec)
            recall_list.append(recall)

        prec = average(prec_list)
        recall = average(recall_list)
        f1 = get_f1(prec, recall)
        res.append((t, prec, recall, f1))

    for t, prec, recall, f1 in res:
        print(t, prec, recall, f1)
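get_f1 is not shown in the snippet above; a minimal sketch, assuming it is the usual harmonic mean of precision and recall with a zero guard:

def get_f1(prec, recall):
    # Harmonic mean of precision and recall; 0.0 when both are zero.
    if prec + recall == 0:
        return 0.0
    return 2 * prec * recall / (prec + recall)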
Example #16
    def __init__(self,
                 encoder,
                 max_seq_length,
                 scores,
                 query_type="title",
                 target_selection="best"):
        self.data = self.load_tokens()
        qrel_path = "/home/youngwookim/Downloads/rob04-desc/qrels.rob04.txt"
        self.judgement = load_qrels_structured(qrel_path)
        self.max_seq_length = max_seq_length
        self.queries = load_robust_04_query(query_type)
        self.encoder = encoder
        self.tokenizer = get_tokenizer()
        self.galago_rank = load_bm25_best()
        self.scores: Dict[Tuple[str, str, int], float] = scores
        self.get_target_indices: Callable[[], List[int]] = {
            'best': get_target_indices_get_best,
            'all': get_target_indices_all,
            'first_and_best': get_target_indices_first_and_best,
            'best_or_over_09': get_target_indices_best_or_over_09,
            'random_over_09': get_target_indices_random_over_09
        }[target_selection]
Example #17
def get_candidate_all_passage_w_samping(
        max_seq_length=256, neg_k=1000) -> Dict[str, List[QCKCandidateWToken]]:
    qrel_path = os.path.join(data_path, "robust", "qrels.rob04.txt")
    galago_rank = load_bm25_best()
    tokens_d = load_robust_tokens_for_train()
    tokens_d.update(load_robust_tokens_for_predict(4))
    queries = load_robust04_title_query()
    tokenizer = get_tokenizer()
    judgement: Dict[str, Dict] = load_qrels_structured(qrel_path)
    out_d: Dict[str, List[QCKCandidateWToken]] = {}
    for query_id in judgement.keys():
        if query_id not in judgement:
            continue
        query = queries[query_id]
        query_tokens = tokenizer.tokenize(query)

        judge_entries = judgement[query_id]
        doc_ids = set(judge_entries.keys())

        ranked_list = galago_rank[query_id]
        ranked_list = ranked_list[:neg_k]
        doc_ids.update([e.doc_id for e in ranked_list])

        candidate = []
        for doc_id in doc_ids:
            tokens = tokens_d[doc_id]
            for idx, passage in enumerate(enum_passage(tokens,
                                                       max_seq_length)):
                if idx == 0:
                    include = True
                else:
                    include = random.random() < 0.1

                if include:
                    c = QCKCandidateWToken(doc_id, "", passage)
                    candidate.append(c)

        out_d[query_id] = candidate
    return out_d
Example #18
def main2():
    rlg_proposed_tfidf = load_ranked_list_grouped(sys.argv[1])
    rlg_proposed_bm25 = load_ranked_list_grouped(sys.argv[2])
    rlg_bert_tfidf = load_ranked_list_grouped(sys.argv[3])
    qrel: QRelsDict = load_qrels_structured(sys.argv[4])

    flat_etr1 = []
    flat_etr3 = []
    for q in rlg_proposed_tfidf:
        entries1 = rlg_proposed_tfidf[q]
        entries2 = rlg_proposed_bm25[q]
        entries3 = rlg_bert_tfidf[q]

        def get_doc_set(entries):
            return set(map(TrecRankedListEntry.get_doc_id, entries))

        docs2 = get_doc_set(entries2)

        d = qrel[q]

        def reform(entries):
            es = list([e for e in entries if e.doc_id not in docs2])

            new_entries = []
            for idx, e in enumerate(es):
                new_entries.append(
                    TrecRankedListEntry(e.query_id, e.doc_id, idx, e.score,
                                        e.run_name))
            return new_entries

        etr1 = reform(entries1)
        flat_etr1.extend(etr1)
        etr3 = reform(entries3)
        flat_etr3.extend(etr3)

    write_trec_ranked_list_entry(flat_etr1, "bm25equi_proposed.txt")
    write_trec_ranked_list_entry(flat_etr3, "bm25equi_bert.txt")
Example #19
def main():
    split = "train"
    resource = ProcessedResource10docMulti(split)

    query_group: List[List[QueryID]] = load_query_group(split)
    msmarco_passage_qrel_path = at_data_dir("msmarco", "qrels.train.tsv")
    passage_qrels: QRelsDict = load_qrels_structured(msmarco_passage_qrel_path)

    qids = query_group[0]
    qids = qids[:100]
    pickle_name = "msmarco_passage_doc_analyze_passage_dict_evidence_loc"
    try:
        passage_dict = load_from_pickle(pickle_name)
    except FileNotFoundError:
        print("Reading passages...")
        passage_dict = get_passages(qids, passage_qrels)
        save_to_pickle(passage_dict, pickle_name)
    def get_rel_doc_id(qid):
        if qid not in resource.get_doc_for_query_d():
            raise KeyError
        for doc_id in resource.get_doc_for_query_d()[qid]:
            label = resource.get_label(qid, doc_id)
            if label:
                return doc_id
        raise KeyError

    def translate_token_idx_to_sent_idx(stemmed_body_tokens_list, loc_in_body):
        acc = 0
        for idx, tokens in enumerate(stemmed_body_tokens_list):
            acc += len(tokens)
            if loc_in_body < acc:
                return idx
        return -1

    pc_tokenize = PCTokenizer()
    bert_tokenizer = get_tokenizer()

    for qid in qids:
        try:
            doc_id = get_rel_doc_id(qid)
            stemmed_tokens_d = resource.get_stemmed_tokens_d(qid)
            stemmed_title_tokens, stemmed_body_tokens_list = stemmed_tokens_d[doc_id]
            rel_passages = list([passage_id for passage_id, score in passage_qrels[qid].items() if score])
            success = False
            found_idx = -1
            for rel_passage_id in rel_passages:
                passage_text = passage_dict[rel_passage_id].strip()
                passage_tokens = pc_tokenize.tokenize_stem(passage_text)
                stemmed_body_tokens_flat = lflatten(stemmed_body_tokens_list)
                n, log = lcs(passage_tokens, stemmed_body_tokens_flat, True)
                if len(passage_tokens) > 4 and n > len(passage_tokens) * 0.7 and n > 0:
                    success = True
                    _, loc_in_body = log[0]

                    sent_idx = translate_token_idx_to_sent_idx(stemmed_body_tokens_list, loc_in_body)
                    prev = stemmed_body_tokens_flat[:loc_in_body]

                    loc_by_bert_tokenize = len(bert_tokenizer.tokenize(" ".join(prev)))
                    print(sent_idx, loc_in_body, loc_by_bert_tokenize, len(stemmed_body_tokens_list))
                    found_idx = sent_idx
            if not success:
                print("Not found. doc_lines={} passage_len={}".format(len(stemmed_body_tokens_list), len(passage_tokens)))

        except KeyError:
            pass
Example #20
def load_robust04_qrels():
    path = os.path.join(robust_path, "qrels.rob04.txt")
    return load_qrels_structured(path)
Example #21
def main():
    prediction_file_path = at_output_dir("robust", "rob_dense2_pred.score")
    info_file_path = at_job_man_dir1("robust_predict_desc_128_step16_2_info")
    queries: Dict[str, str] = load_robust_04_query("desc")
    tokenizer = get_tokenizer()
    query_token_len_d = {}
    for qid, q_text in queries.items():
        query_token_len_d[qid] = len(tokenizer.tokenize(q_text))
    step_size = 16
    window_size = 128
    out_entries: List[AnalyzedDoc] = token_score_by_ablation(
        info_file_path, prediction_file_path, query_token_len_d, step_size,
        window_size)

    qrel_path = "/home/youngwookim/Downloads/rob04-desc/qrels.rob04.txt"
    judgement_d = load_qrels_structured(qrel_path)

    html = HtmlVisualizer("robust_desc_128_step16_2.html", use_tooltip=True)

    tprint("loading tokens pickles")
    tokens_d: Dict[str, List[str]] = load_pickle_from(
        os.path.join(sydney_working_dir, "RobustPredictTokens3", "1"))
    tprint("Now printing")
    n_printed = 0

    def transform(x):
        return 3 * (math.pow(x - 0.5, 3) + math.pow(0.5, 3))

    n_pos = 0
    n_neg = 0
    for e in out_entries:
        max_score: float = max(
            lmap(SegmentScorePair.get_max_score,
                 flatten(e.token_info.values())))
        if max_score < 0.6:
            if n_neg > n_pos:
                continue
            else:
                n_neg += 1
                pass
        else:
            n_pos += 1

        n_printed += 1
        if n_printed > 500:
            break

        doc_tokens: List[str] = tokens_d[e.doc_id]
        score_len = max(e.token_info.keys()) + 1
        judgement: Dict[str, int] = judgement_d[e.query_id]
        label = judgement[e.doc_id]

        if not len(doc_tokens) <= score_len < len(doc_tokens) + window_size:
            print("doc length : ", len(doc_tokens))
            print("score len:", score_len)
            print("doc length +step_size: ", len(doc_tokens) + step_size)
            continue

        row = []
        q_text = queries[e.query_id]
        html.write_paragraph("qid: " + e.query_id)
        html.write_paragraph("q_text: " + q_text)
        html.write_paragraph("Pred: {0:.2f}".format(max_score))
        html.write_paragraph("Label: {0:.2f}".format(label))

        for idx in range(score_len):
            token = doc_tokens[idx] if idx < len(doc_tokens) else '[-]'
            token_info: List[SegmentScorePair] = e.token_info[idx]
            full_scores: List[float] = lmap(SegmentScorePair.get_score_diff,
                                            token_info)

            full_score_str = " ".join(lmap(two_digit_float, full_scores))
            # 1 ~ -1
            score = average(full_scores)
            if score > 0:
                color = "B"
            else:
                color = "R"
            normalized_score = transform(abs(score)) * 200
            c = get_tooltip_cell(token, full_score_str)
            c.highlight_score = normalized_score
            c.target_color = color
            row.append(c)

        html.multirow_print(row, 16)
Example #22
File: robust2.py Project: clover3/Chair
def load_robust_qrel() -> Dict[str, Dict[str, int]]:
    qrel_path = os.path.join(robust_path, "qrels.rob04.txt")
    return load_qrels_structured(qrel_path)
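Most of the snippets above guard the qrels lookup with a membership test or a try/except; an equivalent compact pattern on the dict returned by load_robust_qrel (the query and document IDs here are placeholders):

qrel = load_robust_qrel()
label = qrel.get("301", {}).get("FBIS3-10082", 0)  # 0 for unjudged pairs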
Example #23
def load_qrels_for(year) -> QRelsDict:
    qrel_path = at_data_dir("clueweb", "{}.qrels.txt".format(year))
    return load_qrels_structured(qrel_path)
Example #24
def main():
    prediction_file_path = at_output_dir("robust", "rob_dense_pred.score")
    info_file_path = at_job_man_dir1("robust_predict_desc_128_step16_info")
    queries: Dict[str, str] = load_robust_04_query("desc")
    tokenizer = get_tokenizer()
    query_token_len_d = {}
    for qid, q_text in queries.items():
        query_token_len_d[qid] = len(tokenizer.tokenize(q_text))
    step_size = 16
    window_size = 128
    out_entries: List[DocTokenScore] = collect_token_scores(
        info_file_path, prediction_file_path, query_token_len_d, step_size,
        window_size)

    qrel_path = "/home/youngwookim/Downloads/rob04-desc/qrels.rob04.txt"
    judgement_d = load_qrels_structured(qrel_path)

    html = HtmlVisualizer("robust_desc_128_step16.html", use_tooltip=True)

    tprint("loading tokens pickles")
    tokens_d: Dict[str, List[str]] = load_pickle_from(
        os.path.join(sydney_working_dir, "RobustPredictTokens3", "1"))
    tprint("Now printing")
    n_printed = 0

    def transform(x):
        return 3 * (math.pow(x - 0.5, 3) + math.pow(0.5, 3))

    for e in out_entries:
        max_score = e.max_segment_score()
        if max_score < 0.6:
            continue
        n_printed += 1
        if n_printed > 10:
            break
        doc_tokens: List[str] = tokens_d[e.doc_id]
        score_len = len(e.scores)
        judgement: Dict[str, int] = judgement_d[e.query_id]
        label = judgement[e.doc_id]

        if not len(doc_tokens) <= score_len < len(doc_tokens) + window_size:
            print("doc length : ", len(doc_tokens))
            print("score len:", score_len)
            print("doc length +step_size: ", len(doc_tokens) + step_size)
            raise IndexError

        row = []
        q_text = queries[e.query_id]
        html.write_paragraph("qid: " + e.query_id)
        html.write_paragraph("q_text: " + q_text)
        html.write_paragraph("Pred: {0:.2f}".format(max_score))
        html.write_paragraph("Label: {0:.2f}".format(label))

        for idx in range(score_len):
            token = doc_tokens[idx] if idx < len(doc_tokens) else '[-]'

            full_scores = e.full_scores[idx]
            full_score_str = " ".join(lmap(two_digit_float, full_scores))
            score = e.scores[idx]
            normalized_score = transform(score) * 200
            c = get_tooltip_cell(token, full_score_str)
            c.highlight_score = normalized_score
            row.append(c)

        html.multirow_print(row, 16)
Example #25
    def __init__(self):
        qrel_path = "/home/youngwookim/Downloads/rob04-desc/qrels.rob04.txt"
        self.judgement = load_qrels_structured(qrel_path)
Example #26
def main():
    non_combine_ranked_list: Dict[
        str, List[TrecRankedListEntry]] = load_ranked_list_grouped(sys.argv[1])
    judgement: QRelsDict = load_qrels_structured(sys.argv[2])

    # TODO : Find the case where scoring first segment would get better score
    # TODO : Check if the FP, TP document has score in low range or high range
    # TODO : Check if FP

    def parse_doc_name(doc_name):
        tokens = doc_name.split("_")
        doc_id = "_".join(tokens[:-1])
        passage_idx = int(tokens[-1])
        return doc_id, passage_idx

    class Passage(NamedTuple):
        doc_id: str
        score: float
        passage_idx: int

        def get_doc_id(self):
            return self.doc_id

        def get_score(self):
            return self.score

    class Entry(NamedTuple):
        doc_id: str
        max_score: float
        max_passage_idx: int
        num_passage: int
        passages: List[Passage]

    for query_id, entries in non_combine_ranked_list.items():
        new_e_list = []
        for e in entries:
            doc_id, passage_idx = parse_doc_name(e.doc_id)
            new_e = Passage(doc_id, e.score, passage_idx)
            new_e_list.append(new_e)
        grouped: Dict[str, List[Passage]] = group_by(new_e_list, get_first)
        get_passage_idx = get_third
        get_score = get_second

        grouped_entries = []
        doc_id_to_entry = {}
        score_by_head: List[Tuple[str, float]] = []
        for doc_id, scored_passages in grouped.items():
            scored_passages.sort(key=get_passage_idx)
            first_seg_score = scored_passages[0].score
            score_by_head.append((doc_id, first_seg_score))

            scored_passages.sort(key=Passage.get_score, reverse=True)
            max_score = scored_passages[0].score
            max_passage_idx = scored_passages[0].passage_idx
            num_passage = len(scored_passages)
            e = Entry(doc_id, max_score, max_passage_idx, num_passage,
                      scored_passages)
            scored_passages.sort(key=get_passage_idx)
            grouped_entries.append(e)
            doc_id_to_entry[doc_id] = e

        rel_d = judgement[query_id]

        def is_relevant(doc_id):
            return doc_id in rel_d and rel_d[doc_id]

        score_by_head.sort(key=get_second, reverse=True)
        rel_rank_by_head = -1

        rel_rank = 0
        rel_doc_id = ""
        for rank, (doc_id, score) in enumerate(score_by_head):
            if is_relevant(doc_id):
                rel_rank_by_head = rank
                rel_doc_id = doc_id

        grouped_entries.sort(key=lambda x: x.max_score, reverse=True)
        rel_rank_by_max = -1
        for rank, e in enumerate(grouped_entries):
            if is_relevant(e.doc_id):
                rel_rank_by_max = rank

        def get_passage_score_str(passages: List[Passage]):
            passage_scores: List[float] = lmap(Passage.get_score, passages)
            scores_str = " ".join(map(two_digit_float, passage_scores))
            return scores_str

        if rel_rank_by_head < rel_rank_by_max:
            print()
            print("< Relevant document >")
            print("Rank by head", rel_rank_by_head)
            print("Rank by max", rel_rank_by_max)
            rel_entry = doc_id_to_entry[rel_doc_id]
            print(get_passage_score_str(rel_entry.passages))
            print("Num passages", rel_entry.num_passage)

            for rank, entry in enumerate(grouped_entries):
                if (len(entry.passages) > 1 and entry.doc_id != rel_doc_id
                        and rank < rel_rank_by_max):
                    print("< False positive document >")
                    print("Rank by max:", rank)
                    print("doc_id", entry.doc_id)
                    print("Num passages", entry.num_passage)
                    print("max_score", entry.max_score)
                    print("Passages scores: ",
                          get_passage_score_str(entry.passages))
                    break
Example #27
    def __init__(self):
        super(RobustPreprocessTrain, self).__init__()
        qrel_path = "/home/youngwookim/Downloads/rob04-desc/qrels.rob04.txt"
        self.judgement = load_qrels_structured(qrel_path)