예제 #1
0
def transform(
        t: Tuple[ArguDataPoint, Passage, bool]
) -> Tuple[QCKQuery, QCKCandidate, bool]:
    """Convert a (problem, passage, is_correct) triple into QCK objects.

    The query is built from ``problem.text1`` (its ``id.id`` and ``text``),
    the candidate from the passage's ``id.id`` and ``text``; the boolean is
    passed through unchanged.
    """
    problem, passage, label = t
    query = QCKQuery(problem.text1.id.id, problem.text1.text)
    qck_candidate = QCKCandidate(passage.id.id, passage.text)
    return query, qck_candidate, label
예제 #2
0
def convert(info):
    """Build a light info dict from ``info``'s 'qid' and 'cid' entries.

    Both ids are stringified and wrapped in QCKQuery / QCKCandidate with
    empty text; 'kdp' is a placeholder KnowledgeDocumentPart("", 0, 0, []).
    """
    query = QCKQuery(str(info['qid']), "")
    candidate = QCKCandidate(str(info['cid']), "")
    placeholder_kdp = KnowledgeDocumentPart("", 0, 0, [])
    return {'query': query, 'candidate': candidate, 'kdp': placeholder_kdp}
예제 #3
0
    def generate(self, query_list, data_id_manager) -> Iterator[ClassificationInstanceWDataID]:
        """Yield passage-level classification instances for positive documents.

        For each query in ``query_list`` that has an entry in
        ``self.judgement``, the candidate documents are the judged documents
        plus the first ``self.neg_k`` entries of ``self.galago_rank``. Only
        documents whose judgement value is > 0 produce instances (label 1).
        Each (tokens, seg_ids) passage returned by ``self.encoder.encode``
        becomes one instance; its data id is assigned from a dict holding the
        light query, light candidate and passage index.

        Queries missing from ``self.judgement`` are skipped.
        """
        neg_k = self.neg_k
        for query_id in query_list:
            if query_id not in self.judgement:
                continue

            qck_query = QCKQuery(query_id, "")
            judgement = self.judgement[query_id]
            query = self.queries[query_id]
            query_tokens = self.tokenizer.tokenize(query)

            ranked_list = self.galago_rank[query_id]
            ranked_list = ranked_list[:neg_k]

            target_docs = set(judgement.keys())
            target_docs.update([e.doc_id for e in ranked_list])
            print("Total of {} docs".format(len(target_docs)))

            for doc_id in target_docs:
                label = 1 if doc_id in judgement and judgement[doc_id] > 0 else 0
                # Non-positive documents never produce output, so decide that
                # before paying for the document lookup and passage encoding.
                if not label:
                    continue
                doc_tokens = self.data[doc_id]
                passage_list = self.encoder.encode(query_tokens, doc_tokens)
                candidate = QCKCandidate(doc_id, "")
                for idx, (passage_tokens, seg_ids) in enumerate(passage_list):
                    info = {
                        'query': get_light_qckquery(qck_query),
                        'candidate': get_light_qckcandidate(candidate),
                        'idx': idx,
                    }
                    data_id = data_id_manager.assign(info)
                    inst = ClassificationInstanceWDataID(passage_tokens, seg_ids, label, data_id)
                    yield inst
예제 #4
0
    def generate(self, query_list, data_id_manager) -> Iterator[QueryDocInstance]:
        """Yield one QueryDocInstance per candidate document of each judged query.

        Candidates are the judged documents plus the first ``self.neg_k``
        entries of ``self.galago_rank``. Document tokens are truncated to
        ``self.doc_max_length``. The label is 1 when the judgement value is
        > 0, else 0; negatives are skipped when ``self.pos_only`` is set. The
        data id is assigned from a dict with the light query, the light
        candidate and the query token count.
        """
        top_n = self.neg_k
        for qid in query_list:
            if qid not in self.judgement:
                continue

            qck_query = QCKQuery(qid, "")
            qrels = self.judgement[qid]
            q_tokens = self.tokenizer.tokenize(self.queries[qid])

            top_ranked = self.galago_rank[qid][:top_n]
            candidate_ids = set(qrels.keys())
            candidate_ids.update(entry.doc_id for entry in top_ranked)
            print("Total of {} docs".format(len(candidate_ids)))

            for doc_id in candidate_ids:
                doc_tokens = self.data[doc_id][:self.doc_max_length]
                is_relevant = doc_id in qrels and qrels[doc_id] > 0
                label = 1 if is_relevant else 0
                if self.pos_only and not label:
                    continue
                candidate = QCKCandidate(doc_id, "")
                info = dict(
                    query=get_light_qckquery(qck_query),
                    candidate=get_light_qckcandidate(candidate),
                    q_term_len=len(q_tokens),
                )
                yield QueryDocInstance(q_tokens, doc_tokens, label, data_id_manager.assign(info))
예제 #5
0
파일: eval.py 프로젝트: clover3/Chair
def baseline_eval():
    """Print precision/recall of the top-30 'train' candidate lists.

    Each (qid, ranked candidates) pair from ``get_candidate_w_score`` is
    paired with ``QCKQuery(qid, "")`` and the list is scored by
    ``get_precision_recall``.
    """
    split, k = "train", 30
    candidates_by_qid: Dict[str, List[QCKCandidate]] = get_candidate_w_score(split, k)
    ranked_lists = [(QCKQuery(qid, ""), ranked) for qid, ranked in candidates_by_qid.items()]
    print(get_precision_recall(ranked_lists))
예제 #6
0
    def generate(self, data_id_manager, qids):
        """Yield passage-level ClassificationInstanceWDataID objects for ``qids``.

        For every qid that has candidate documents in
        ``self.resource.get_doc_for_query_d()``, each candidate document is
        looked up in that query's BERT-token and stemmed-token dicts and
        passed to ``self.encoder.encode`` with the query text. Every resulting
        (tokens, seg_ids) passage becomes one instance labelled by
        ``self.resource.get_label``.

        Documents missing from either token dict are counted and skipped.
        Once more than 10 are missing, the offending qids and the number of
        successfully encoded documents are printed and KeyError is raised.

        Raises:
            KeyError: after more than 10 candidate documents lacked tokens.
        """
        missing_cnt = 0
        success_docs = 0
        missing_doc_qid = []
        # Fetch the qid -> candidate-docs mapping once instead of twice per qid.
        doc_for_query_d = self.resource.get_doc_for_query_d()
        for qid in qids:
            if qid not in doc_for_query_d:
                # assert not self.resource.query_in_qrel(qid)
                continue

            query_text = self.resource.get_query_text(qid)
            bert_tokens_d = self.resource.get_bert_tokens_d(qid)
            stemmed_tokens_d = self.resource.get_stemmed_tokens_d(qid)
            for doc_id in doc_for_query_d[qid]:
                label = self.resource.get_label(qid, doc_id)
                # Only the token lookups are expected to miss. Keeping the try
                # this narrow stops a KeyError raised inside the encoder from
                # being silently counted as a "missing document".
                try:
                    bert_title_tokens, bert_body_tokens_list = bert_tokens_d[doc_id]
                    stemmed_title_tokens, stemmed_body_tokens_list = stemmed_tokens_d[doc_id]
                except KeyError as e:
                    missing_cnt += 1
                    missing_doc_qid.append(qid)
                    if missing_cnt > 10:
                        print(missing_doc_qid)
                        print("success: ", success_docs)
                        raise KeyError(doc_id) from e
                    continue

                insts: List[Tuple[List, List]] = self.encoder.encode(
                    query_text,
                    stemmed_title_tokens,
                    stemmed_body_tokens_list,
                    bert_title_tokens,
                    bert_body_tokens_list
                )

                for passage_idx, (tokens_seg, seg_ids) in enumerate(insts):
                    assert type(tokens_seg[0]) == str
                    assert type(seg_ids[0]) == int
                    data_id = data_id_manager.assign({
                        'query': QCKQuery(qid, ""),
                        'candidate': QCKCandidate(doc_id, ""),
                        'passage_idx': passage_idx,
                    })
                    yield ClassificationInstanceWDataID(
                        tokens_seg, seg_ids, label, data_id)
                success_docs += 1
예제 #7
0
파일: common.py 프로젝트: clover3/Chair
def get_qck_queries_all() -> List[QCKQuery]:
    """Build one QCKQuery for every perspective cluster.

    The query text is the claim text, a space, then the text of the
    cluster's smallest perspective id. The query id comes from
    ``get_pc_cluster_query_id``.
    """
    clusters = enum_perspective_clusters()
    claim_text_d: Dict[int, str] = get_all_claim_d()

    def to_query(pc):
        claim_text = claim_text_d[pc.claim_id]
        perspective_text = perspective_getter(min(pc.perspective_ids))
        joined = claim_text + " " + perspective_text
        return QCKQuery(get_pc_cluster_query_id(pc), joined)

    return [to_query(pc) for pc in clusters]
예제 #8
0
파일: common.py 프로젝트: clover3/Chair
def get_qck_queries(split) -> List[QCKQuery]:
    """Build a QCKQuery for each perspective cluster of ``split``.

    Only clusters whose claim id appears in ``load_claim_ids_for_split(split)``
    are kept. The query text is the claim text, a space, then the text of the
    cluster's smallest perspective id.
    """
    allowed_claim_ids = set(load_claim_ids_for_split(split))
    clusters = enum_perspective_clusters_for_split(split)
    claim_text_d: Dict[int, str] = get_all_claim_d()

    queries = []
    for pc in clusters:
        if pc.claim_id not in allowed_claim_ids:
            continue
        claim_text = claim_text_d[pc.claim_id]
        perspective_text = perspective_getter(min(pc.perspective_ids))
        joined = claim_text + " " + perspective_text
        queries.append(QCKQuery(get_pc_cluster_query_id(pc), joined))

    return queries
예제 #9
0
파일: gen_worker.py 프로젝트: clover3/Chair
    def generate(self, data_id_manager, qids):
        """Yield passage-level ClassificationInstanceWDataID objects for ``qids``.

        Each qid present in ``self.resource.candidate_doc_d`` has its
        candidate documents looked up in ``self.resource.get_doc_tokens_d(qid)``
        as (title_tokens, body_tokens) pairs. Each pair is encoded with the
        query tokens by ``self.doc_encoder``, and every resulting
        (tokens, seg_ids) passage becomes one instance labelled by
        ``self.resource.get_label``. A qid without candidates is asserted to
        have no qrel entry and is skipped.

        Documents missing from the token dict are counted and skipped. Once
        more than 10 are missing, the success count is printed and KeyError
        is raised. When the generator is exhausted, the encoder's long-title
        counters are printed.

        Raises:
            KeyError: after more than 10 candidate documents lacked tokens.
        """
        missing_cnt = 0
        success_docs = 0
        for qid in qids:
            if qid not in self.resource.candidate_doc_d:
                assert qid not in self.resource.qrel.qrel_d
                continue

            tokens_d: Dict[str, Tuple[List, List]] = self.resource.get_doc_tokens_d(qid)
            q_tokens = self.resource.get_q_tokens(qid)
            for doc_id in self.resource.candidate_doc_d[qid]:
                label = self.resource.get_label(qid, doc_id)
                # Only the token lookup is expected to miss. A KeyError from
                # the encoder is a real bug and must not be counted as a
                # missing document.
                try:
                    title_tokens, body_tokens = tokens_d[doc_id]
                except KeyError as e:
                    missing_cnt += 1
                    if missing_cnt > 10:
                        print("success: ", success_docs)
                        raise KeyError(doc_id) from e
                    continue

                insts: List[Tuple[List, List]] = self.doc_encoder.encode(q_tokens, title_tokens, body_tokens)
                for passage_idx, (tokens_seg, seg_ids) in enumerate(insts):
                    assert type(tokens_seg[0]) == str
                    assert type(seg_ids[0]) == int
                    data_id = data_id_manager.assign({
                        'query': QCKQuery(qid, ""),
                        'candidate': QCKCandidate(doc_id, ""),
                        'passage_idx': passage_idx,
                    })
                    yield ClassificationInstanceWDataID(tokens_seg, seg_ids, label, data_id)
                success_docs += 1
        print(" {} of {} has long title".format(self.doc_encoder.long_title_cnt, self.doc_encoder.total_doc_cnt))
예제 #10
0
 def generate(self, query_list, data_id_manager):
     """Build Instances for the first ``self.top_k`` galago-ranked docs of each query.

     Queries without a ``self.galago_rank`` entry are skipped. Each
     document's tokens (``self.data[doc_id]``) are encoded with the query
     tokens, and every (tokens, seg_ids) passage yields
     ``Instance(tokens, seg_ids, data_id, 0)``. The data id is assigned
     from the light query, the light candidate and the passage index.

     Returns:
         A list of all instances, in the order queries, ranked entries and
         passages are visited.
     """
     collected = []
     for qid in query_list:
         if qid not in self.galago_rank:
             continue
         query_text = self.queries[qid]
         qck_query = QCKQuery(qid, "")
         q_tokens = self.tokenizer.tokenize(query_text)
         top_entries = self.galago_rank[qid][:self.top_k]
         for doc_id, _, _ in top_entries:
             passages = self.encoder.encode(q_tokens, self.data[doc_id])
             candidate = QCKCandidate(doc_id, "")
             for passage_idx, (passage_tokens, passage_seg_ids) in enumerate(passages):
                 data_id = data_id_manager.assign({
                     'query': get_light_qckquery(qck_query),
                     'candidate': get_light_qckcandidate(candidate),
                     'idx': passage_idx
                 })
                 collected.append(Instance(passage_tokens, passage_seg_ids, data_id, 0))
     return collected
예제 #11
0
    def generate(self, query_list,
                 data_id_manager) -> Iterator[ClassificationInstanceWDataID]:
        """Yield one classification instance per encoded segment of each candidate.

        For each query in ``query_list`` that has an entry in
        ``self.judgement``, candidates are the first ``self.neg_k`` entries of
        ``self.galago_rank``. All judged documents are added as well when
        ``self.include_all_judged`` is set. Every item produced by
        ``self.encoder.encode(query_id, doc_id)`` becomes an instance via
        ``ClassificationInstanceWDataID.make_from_tas``. The label is 1 when
        the document's judgement value is > 0, else 0.

        Queries missing from ``self.judgement`` are skipped.
        """
        neg_k = self.neg_k
        for query_id in query_list:
            if query_id not in self.judgement:
                continue

            qck_query = QCKQuery(query_id, "")
            judgement = self.judgement[query_id]
            ranked_list = self.galago_rank[query_id]
            ranked_list = ranked_list[:neg_k]

            target_docs = set()
            docs_in_ranked_list = [e.doc_id for e in ranked_list]
            target_docs.update(docs_in_ranked_list)

            if self.include_all_judged:
                docs_in_judgements = judgement.keys()
                target_docs.update(docs_in_judgements)

            print("Total of {} docs".format(len(target_docs)))
            for doc_id in target_docs:
                # Label and candidate depend only on doc_id, so compute them
                # once per document rather than once per encoded segment.
                label = 1 if doc_id in judgement and judgement[doc_id] > 0 else 0
                candidate = QCKCandidate(doc_id, "")
                for tas in self.encoder.encode(query_id, doc_id):
                    # NOTE(review): info carries no per-segment index, so all
                    # segments of one document are assigned identical info
                    # dicts. Confirm this is intended.
                    info = {
                        'query': get_light_qckquery(qck_query),
                        'candidate': get_light_qckcandidate(candidate),
                    }
                    data_id = data_id_manager.assign(info)
                    inst = ClassificationInstanceWDataID.make_from_tas(
                        tas, label, data_id)
                    yield inst
예제 #12
0
파일: gen_worker.py 프로젝트: clover3/Chair
    def generate(self, data_id_manager, qids):
        """Yield passage-level ClassificationInstanceWDataID objects for ``qids``.

        Each qid present in ``self.resource.candidate_doc_d`` has its
        candidate documents (``self.resource.get_candidate_doc_d(qid)``)
        looked up in ``self.resource.get_doc_tokens_d(qid)``. Each document is
        encoded with the query tokens, and every resulting (tokens, seg_ids)
        passage becomes one instance labelled by ``self.resource.get_label``.
        A qid without candidates is asserted to have no qrel entry and is
        skipped.

        Documents missing from the token dict are counted and skipped. Once
        more than 10 are missing, the offending qids and the success count
        are printed and KeyError is raised.

        Raises:
            KeyError: after more than 10 candidate documents lacked tokens.
        """
        missing_cnt = 0
        success_docs = 0
        missing_doc_qid = []
        for qid in qids:
            if qid not in self.resource.candidate_doc_d:
                assert qid not in self.resource.qrel.qrel_d
                continue

            tokens_d = self.resource.get_doc_tokens_d(qid)
            q_tokens = self.resource.get_q_tokens(qid)
            for doc_id in self.resource.get_candidate_doc_d(qid):
                label = self.resource.get_label(qid, doc_id)
                # Only the token lookup is expected to miss. A KeyError from
                # the encoder must not be counted as a missing document.
                try:
                    doc_tokens = tokens_d[doc_id]
                except KeyError as e:
                    missing_cnt += 1
                    missing_doc_qid.append(qid)
                    if missing_cnt > 10:
                        print(missing_doc_qid)
                        print("success: ", success_docs)
                        raise KeyError(doc_id) from e
                    continue

                insts: List[Tuple[List, List]] = self.encoder.encode(q_tokens, doc_tokens)
                for passage_idx, (tokens_seg, seg_ids) in enumerate(insts):
                    assert type(tokens_seg[0]) == str
                    assert type(seg_ids[0]) == int
                    data_id = data_id_manager.assign({
                        'query': QCKQuery(qid, ""),
                        'candidate': QCKCandidate(doc_id, ""),
                        'passage_idx': passage_idx,
                    })
                    yield ClassificationInstanceWDataID(tokens_seg, seg_ids, label, data_id)
                success_docs += 1
예제 #13
0
파일: qck_common.py 프로젝트: clover3/Chair
 def claim_to_query(claim: Dict):
     """Wrap a claim dict as QCKQuery(str(claim['cId']), claim['text'])."""
     query_id = str(claim['cId'])
     return QCKQuery(query_id, claim['text'])
예제 #14
0
def to_qck_queries(queries):
    """Convert a ``{query_id: query_text}`` mapping into a list of QCKQuery.

    The output order follows the mapping's iteration order.
    """
    return [QCKQuery(qid, text) for qid, text in queries.items()]
예제 #15
0
def light_query(obj: QCKQueryWToken):
    """Return a text-free QCKQuery that keeps only ``obj.query_id``."""
    query_id = obj.query_id
    return QCKQuery(query_id, "")
예제 #16
0
 def problem_to_qckquery(problem: ArguDataPoint):
     """Build a QCKQuery from ``problem.text1``: id str(text1.id.id), text text1.text."""
     first_text = problem.text1
     return QCKQuery(str(first_text.id.id), first_text.text)