コード例 #1
0
    def generate(self, data_id_manager: DataIDManager,
                 query_list,
                 neg_k: int = 1000) -> List[ClassificationInstanceWDataID]:
        """Build labelled passage-level classification instances for the given queries.

        For each query id in ``query_list`` that has relevance judgements in
        ``self.judgement``:

        * the query text ``self.queries[query_id]`` is tokenized with
          ``self.tokenizer``;
        * the candidate documents are every judged document plus the top
          ``neg_k`` entries of ``self.galago_rank[query_id]``;
        * each candidate's tokens ``self.data[doc_id]`` are encoded together
          with the query tokens by ``self.encoder.encode`` into a list of
          ``(tokens, segment_ids)`` passages.

        A document is labelled 1 when its judgement value is > 0, else 0.
        For positive documents the passages to keep are chosen by
        ``self.get_target_indices`` from the per-passage scores stored in
        ``self.scores[query_id, doc_id, passage_idx]``. For negative documents,
        passage 0 is always kept. With probability 0.1, and only when there is
        more than one passage, one extra passage is drawn uniformly from the
        remaining ones using the global ``random`` module state.

        Args:
            data_id_manager: registry whose ``assign(info)`` returns a data id
                for the given info dict.
            query_list: iterable of query ids.
            neg_k: how many top-ranked documents from ``self.galago_rank`` to
                add as candidates (default 1000, the previously hard-coded
                value).

        Returns:
            List of ``ClassificationInstanceWDataID``, one per kept passage.
        """
        all_insts = []
        for query_id in query_list:
            # Queries without judgements cannot be labelled.
            if query_id not in self.judgement:
                continue

            judgement = self.judgement[query_id]
            query_tokens = self.tokenizer.tokenize(self.queries[query_id])
            ranked_list = self.galago_rank[query_id][:neg_k]

            target_docs = set(judgement.keys())
            target_docs.update([e.doc_id for e in ranked_list])
            print("Total of {} docs".format(len(target_docs)))

            for doc_id in target_docs:
                insts: List[Tuple[List, List]] = self.encoder.encode(
                    query_tokens, self.data[doc_id])
                label = 1 if doc_id in judgement and judgement[doc_id] > 0 else 0
                if label:
                    passage_scores = [
                        self.scores[query_id, doc_id, idx]
                        for idx, _ in enumerate(insts)
                    ]
                    target_indices = self.get_target_indices(passage_scores)
                else:
                    # Always keep the first passage of a negative doc and,
                    # occasionally, one other random passage for variety.
                    # random.random() is drawn before the length check, as
                    # in the original, so the RNG stream is unchanged.
                    target_indices = [0]
                    n = len(insts)
                    if random.random() < 0.1 and n > 1:
                        target_indices.append(random.randint(1, n - 1))

                for passage_idx in target_indices:
                    tokens_seg, seg_ids = insts[passage_idx]
                    assert isinstance(tokens_seg[0], str)
                    assert isinstance(seg_ids[0], int)
                    data_id = data_id_manager.assign({
                        'doc_id': doc_id,
                        'passage_idx': passage_idx,
                        'label': label,
                        'tokens': tokens_seg,
                        'seg_ids': seg_ids,
                    })
                    all_insts.append(
                        ClassificationInstanceWDataID(tokens_seg, seg_ids,
                                                      label, data_id))

        return all_insts
コード例 #2
0
ファイル: max_sent_encode.py プロジェクト: clover3/Chair
    def generate(self, data_id_manager, qids):
        """Yield a classification instance for every encoded passage of each candidate doc.

        For each query id in ``qids`` that appears in
        ``self.resource.get_doc_for_query_d()``:

        * every candidate document of that query is looked up in the
          per-query BERT-token and stemmed-token tables;
        * the document is encoded into passages by ``self.encoder.encode``;
        * each passage is yielded as a ``ClassificationInstanceWDataID``
          whose data id comes from ``data_id_manager.assign``.

        Candidate documents missing from either token table are skipped and
        counted. Once more than 10 have been missed, the query ids of the
        misses and the number of successfully encoded documents are printed
        and ``KeyError`` is raised.

        Args:
            data_id_manager: registry whose ``assign(info)`` returns a data id.
            qids: iterable of query ids.

        Yields:
            ``ClassificationInstanceWDataID`` per (query, document, passage).

        Raises:
            KeyError: when more than 10 candidate documents are missing from
                the token tables.
        """
        missing_cnt = 0
        success_docs = 0
        missing_doc_qid = []
        for qid in qids:
            doc_for_query_d = self.resource.get_doc_for_query_d()
            if qid not in doc_for_query_d:
                continue

            query_text = self.resource.get_query_text(qid)
            bert_tokens_d = self.resource.get_bert_tokens_d(qid)
            stemmed_tokens_d = self.resource.get_stemmed_tokens_d(qid)
            for doc_id in doc_for_query_d[qid]:
                label = self.resource.get_label(qid, doc_id)
                # The try covers only the token-table lookups. A KeyError
                # raised inside the encoder or the data-id manager is a real
                # bug. It must not be miscounted as a missing document, nor
                # surface after some of this doc's passages were yielded.
                try:
                    bert_title_tokens, bert_body_tokens_list = \
                        bert_tokens_d[doc_id]
                    stemmed_title_tokens, stemmed_body_tokens_list = \
                        stemmed_tokens_d[doc_id]
                except KeyError as e:
                    missing_cnt += 1
                    missing_doc_qid.append(qid)
                    if missing_cnt > 10:
                        print(missing_doc_qid)
                        print("success: ", success_docs)
                        raise KeyError(
                            "{} candidate documents are missing from the "
                            "token tables".format(missing_cnt)) from e
                    continue

                insts: List[Tuple[List, List]] = self.encoder.encode(
                    query_text,
                    stemmed_title_tokens,
                    stemmed_body_tokens_list,
                    bert_title_tokens,
                    bert_body_tokens_list,
                )
                for passage_idx, (tokens_seg, seg_ids) in enumerate(insts):
                    assert isinstance(tokens_seg[0], str)
                    assert isinstance(seg_ids[0], int)
                    data_id = data_id_manager.assign({
                        'query': QCKQuery(qid, ""),
                        'candidate': QCKCandidate(doc_id, ""),
                        'passage_idx': passage_idx,
                    })
                    yield ClassificationInstanceWDataID(
                        tokens_seg, seg_ids, label, data_id)
                success_docs += 1
コード例 #3
0
    def generate(self, query_list,
                 data_id_manager) -> Iterator[QueryDocInstance]:
        """Yield labelled classification instances for query/document pairs.

        For each query id in ``query_list`` that has an entry in
        ``self.judgement``, the candidate documents are:

        * the first ``self.neg_k`` entries of ``self.galago_rank[query_id]``;
        * every judged document, when ``self.include_all_judged`` is truthy.

        Every item produced by ``self.encoder.encode(query_id, doc_id)``
        becomes one ``ClassificationInstanceWDataID``. It is labelled 1 if the
        document's judgement value is > 0 and 0 otherwise. Its data id comes
        from ``data_id_manager.assign`` with lightweight query/candidate
        descriptors.

        NOTE(review): the return annotation says ``Iterator[QueryDocInstance]``
        while the yielded objects are ``ClassificationInstanceWDataID`` —
        confirm which type is intended.
        """
        top_k = self.neg_k
        for qid in query_list:
            # Only judged queries can be labelled.
            if qid not in self.judgement:
                continue

            query_obj = QCKQuery(qid, "")
            qrel = self.judgement[qid]
            top_ranked = self.galago_rank[qid][:top_k]

            candidate_ids = {entry.doc_id for entry in top_ranked}
            if self.include_all_judged:
                candidate_ids.update(qrel.keys())

            print("Total of {} docs".format(len(candidate_ids)))
            for doc_id in candidate_ids:
                for tas in self.encoder.encode(qid, doc_id):
                    relevant = doc_id in qrel and qrel[doc_id] > 0
                    label = 1 if relevant else 0
                    cand_obj = QCKCandidate(doc_id, "")
                    light_query = get_light_qckquery(query_obj)
                    light_cand = get_light_qckcandidate(cand_obj)
                    data_id = data_id_manager.assign({
                        'query': light_query,
                        'candidate': light_cand,
                    })
                    yield ClassificationInstanceWDataID.make_from_tas(
                        tas, label, data_id)
コード例 #4
0
    def generate(self, data_id_manager, qids):
        """Yield a classification instance for every encoded passage of each candidate doc.

        For each query id in ``qids`` that appears in
        ``self.resource.candidate_doc_d``:

        * the query's documents are loaded with ``load_per_query_docs(qid,
          None)`` and indexed by ``doc_id``;
        * each candidate document is encoded from its title and body by
          ``self.encoder.encode``;
        * each passage is yielded as a ``ClassificationInstanceWDataID``
          whose data id comes from ``data_id_manager.assign``.

        Candidate documents absent from the loaded documents are skipped and
        counted. Once more than 10 have been missed, the query ids of the
        misses and the number of successfully encoded documents are printed
        and ``KeyError`` is raised.

        Args:
            data_id_manager: registry whose ``assign(info)`` returns a data id.
            qids: iterable of query ids.

        Yields:
            ``ClassificationInstanceWDataID`` per (query, document, passage).

        Raises:
            KeyError: when more than 10 candidate documents are missing from
                the loaded per-query documents.
        """
        missing_cnt = 0
        success_docs = 0
        missing_doc_qid = []
        for qid in qids:
            if qid not in self.resource.candidate_doc_d:
                continue

            docs: List[MSMarcoDoc] = load_per_query_docs(qid, None)
            docs_d = {d.doc_id: d for d in docs}

            q_tokens = self.resource.get_q_tokens(qid)
            for doc_id in self.resource.candidate_doc_d[qid]:
                label = self.resource.get_label(qid, doc_id)
                # The try covers only the document lookup. A KeyError raised
                # inside the encoder or the data-id manager must not be
                # miscounted as a missing document, nor surface after some of
                # this doc's passages were already yielded.
                try:
                    doc = docs_d[doc_id]
                except KeyError as e:
                    missing_cnt += 1
                    missing_doc_qid.append(qid)
                    if missing_cnt > 10:
                        print(missing_doc_qid)
                        print("success: ", success_docs)
                        raise KeyError(
                            "{} candidate documents are missing from the "
                            "loaded documents".format(missing_cnt)) from e
                    continue

                insts: List[Tuple[List, List]] = self.encoder.encode(
                    q_tokens, doc.title, doc.body)
                for passage_idx, (tokens_seg, seg_ids) in enumerate(insts):
                    assert isinstance(tokens_seg[0], str)
                    assert isinstance(seg_ids[0], int)
                    data_id = data_id_manager.assign({
                        'query': QCKQuery(qid, ""),
                        'candidate': QCKCandidate(doc_id, ""),
                        'passage_idx': passage_idx,
                    })
                    yield ClassificationInstanceWDataID(
                        tokens_seg, seg_ids, label, data_id)
                success_docs += 1