Example #1
def get_candidate_all_passage_w_samping_predict(
        max_seq_length=256) -> Dict[str, List[QCKCandidateWToken]]:
    galago_rank = load_bm25_best()
    tokens_d = load_robust_tokens_for_predict(4)
    queries = load_robust04_title_query()
    tokenizer = get_tokenizer()
    out_d: Dict[str, List[QCKCandidateWToken]] = {}
    for query_id in queries:
        query = queries[query_id]
        query_tokens = tokenizer.tokenize(query)

        # Keep only the top-100 BM25-ranked documents for this query.
        ranked_list = galago_rank[query_id][:100]
        doc_ids = [e.doc_id for e in ranked_list]

        candidate = []
        for doc_id in doc_ids:
            tokens = tokens_d[doc_id]
            for idx, passage in enumerate(enum_passage(tokens, max_seq_length)):
                # Always keep the first passage; sample later ones at 10%.
                include = idx == 0 or random.random() < 0.1
                if include:
                    candidate.append(QCKCandidateWToken(doc_id, "", passage))

        out_d[query_id] = candidate
    return out_d
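
All of the examples on this page call enum_passage, whose definition is not shown. Below is a minimal sketch of the behavior the call sites imply, assuming non-overlapping fixed-length chunking; the real implementation may differ (e.g. overlapping windows or padding of the final chunk).

from typing import Iterator, List

def enum_passage(tokens: List[str], content_len: int) -> Iterator[List[str]]:
    # Yield consecutive chunks of at most content_len tokens.
    assert content_len > 0
    for start in range(0, len(tokens), content_len):
        yield tokens[start:start + content_len]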
Example #2
def make_candidate(doc_id: str) -> Iterable[QCKCandidateWToken]:
    tokens = token_data[doc_id]
    for idx, passage_tokens in enumerate(enum_passage(tokens, content_len)):
        # Cap the number of passages taken from a single document.
        if idx >= max_passage_per_doc:
            break
        doc_part_id = "{}_{}".format(doc_id, idx)
        yield QCKCandidateWToken(doc_part_id, "", passage_tokens)
Example #3
def count(self, query_tokens, tokens) -> List[Tuple[List, List, int]]:
    # Reserve 3 slots for [CLS] and the two [SEP] tokens.
    content_len = self.max_seq_length - 3 - len(query_tokens)
    insts = []
    for second_tokens in enum_passage(tokens, content_len):
        out_tokens = ["[CLS]"] + query_tokens + ["[SEP]"] + second_tokens + ["[SEP]"]
        segment_ids = [0] * (len(query_tokens) + 2) + [1] * (len(second_tokens) + 1)
        insts.append((out_tokens, segment_ids, len(second_tokens)))
    return insts
Example #4
def get_candidate_for_query(query: QCKQuery):
    # Retrieve the top-60 evidence texts for the query from the pool.
    res = get_evidence_from_pool(query.text, 60)
    query_len = len(tokenizer.tokenize(query.text))
    candidate_max_len = max_seq_length - 3 - query_len

    output = []
    for text, e_id, score in res:
        tokens = tokenizer.tokenize(text)
        for passage in enum_passage(tokens, candidate_max_len):
            output.append(QCKCandidateWToken(str(e_id), "", passage))
    return output
Example #5
def encode(self, query_tokens, tokens) -> List[Tuple[List, List]]:
    content_len = self.max_seq_length - 3 - len(query_tokens)
    insts = []
    for idx, second_tokens in enumerate(enum_passage(tokens, content_len)):
        # Encode at most num_segment passages per document.
        if idx == self.num_segment:
            break
        out_tokens = ["[CLS]"] + query_tokens + ["[SEP]"] + second_tokens + ["[SEP]"]
        segment_ids = [0] * (len(query_tokens) + 2) + [1] * (len(second_tokens) + 1)
        insts.append((out_tokens, segment_ids))
    return insts
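
To make the shape of each instance concrete, here is a tiny standalone run of the same pattern, using the enum_passage sketch above and made-up inputs (max_seq_length = 16 only for the demo):

query_tokens = ["what", "is", "bm25"]
doc_tokens = ["term", "weighting"] * 40
content_len = 16 - 3 - len(query_tokens)  # demo max_seq_length = 16

for second_tokens in enum_passage(doc_tokens, content_len):
    out_tokens = ["[CLS]"] + query_tokens + ["[SEP]"] + second_tokens + ["[SEP]"]
    segment_ids = [0] * (len(query_tokens) + 2) + [1] * (len(second_tokens) + 1)
    # Segment 0 covers [CLS] + query + first [SEP];
    # segment 1 covers the passage + final [SEP].
    assert len(out_tokens) == len(segment_ids) <= 16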
Example #6
def encode(self, query_tokens, tokens) -> List[Tuple[List, List]]:
    content_len = self.max_seq_length - 3 - len(query_tokens)
    insts = []
    for idx, second_tokens in enumerate(enum_passage(tokens, content_len)):
        # Keep passage idx with probability g_factor ** idx, so later
        # passages are sampled geometrically less often.
        chance = math.pow(self.g_factor, idx)
        if random.random() < chance:
            out_tokens = ["[CLS]"] + query_tokens + ["[SEP]"] + second_tokens + ["[SEP]"]
            segment_ids = [0] * (len(query_tokens) + 2) + [1] * (len(second_tokens) + 1)
            insts.append((out_tokens, segment_ids))
    return insts
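
The sampling above keeps passage idx with probability g_factor ** idx: the first passage is always kept, and the expected number of kept passages is bounded by 1 / (1 - g_factor) regardless of document length. An isolated demo of the schedule (the function name here is illustrative, not from the original code):

import math
import random

def sample_passage_indices(n_passages: int, g_factor: float) -> list:
    # Keep index idx with probability g_factor ** idx; idx 0 is always
    # kept since g_factor ** 0 == 1.0 and random.random() < 1.0.
    return [idx for idx in range(n_passages)
            if random.random() < math.pow(g_factor, idx)]

print(sample_passage_indices(10, 0.5))  # e.g. [0, 1, 3]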
Example #7
def encode(self, query_tokens, tokens) -> List[Tuple[List, List]]:
    content_per_window = self.sero_window_size - 3 - len(query_tokens)
    # Allow up to four windows' worth of content, capped by max_seq_length.
    sero_content_length = content_per_window * 4
    content_max_len = self.max_seq_length - 3 - len(query_tokens)
    content_len = min(sero_content_length, content_max_len)
    insts = []
    for second_tokens in enum_passage(tokens, content_len):
        out_tokens = ["[CLS]"] + query_tokens + ["[SEP]"] + second_tokens + ["[SEP]"]
        segment_ids = [0] * (len(query_tokens) + 2) + [1] * (len(second_tokens) + 1)
        insts.append((out_tokens, segment_ids))
        # Only the first (long) passage is used.
        break
    return insts
Example #8
def encode(self, query_tokens, title_tokens,
           body_tokens) -> List[Tuple[List, List]]:
    # title_tokens is unused in this variant; only the body is encoded.
    self.total_doc_cnt += 1
    content_len = self.max_seq_length - 3 - len(query_tokens)
    assert content_len > 5
    insts = []
    for passage_tokens in enum_passage(body_tokens, content_len):
        out_tokens = ["[CLS]"] + query_tokens + ["[SEP]"] + passage_tokens + ["[SEP]"]
        segment_ids = [0] * (len(query_tokens) + 2) + [1] * (len(passage_tokens) + 1)
        insts.append((out_tokens, segment_ids))
    return insts
Example #9
def encode(self, query_tokens, tokens) -> List[Tuple[List, List]]:
    # Truncate overly long queries to 64 tokens.
    if len(query_tokens) > 64:
        query_tokens = query_tokens[:64]
    content_len = self.max_seq_length - 3 - len(query_tokens)
    # Guard against empty documents so enum_passage yields at least one passage.
    if not tokens:
        tokens = ['[PAD]']
    insts = []
    for second_tokens in enum_passage(tokens, content_len):
        out_tokens = ["[CLS]"] + query_tokens + ["[SEP]"] + second_tokens + ["[SEP]"]
        segment_ids = [0] * (len(query_tokens) + 2) + [1] * (len(second_tokens) + 1)
        insts.append((out_tokens, segment_ids))
    return insts
Example #10
def encode(self, query_tokens, tokens) -> List[Tuple[List, List]]:
    content_len = self.window_size - 3 - len(query_tokens)
    tokens_extending = []
    segment_ids_extending = []

    for idx, second_tokens in enumerate(enum_passage(tokens, content_len)):
        if idx == self.num_segment:
            break
        out_tokens = ["[CLS]"] + query_tokens + ["[SEP]"] + second_tokens + ["[SEP]"]
        segment_ids = [0] * (len(query_tokens) + 2) + [1] * (len(second_tokens) + 1)

        # Each appended window is expected to occupy exactly window_size positions.
        assert len(tokens_extending) % self.window_size == 0
        assert len(segment_ids_extending) % self.window_size == 0
        tokens_extending.extend(out_tokens)
        segment_ids_extending.extend(segment_ids)
    # All windows are concatenated into a single long sequence.
    return [(tokens_extending, segment_ids_extending)]
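
The modulus asserts in this variant only hold if each appended window fills exactly self.window_size positions, so padding presumably happens around the extend in the surrounding class. A hedged sketch of such a padding step (hypothetical helper, not from the source):

def pad_to_window(out_tokens, segment_ids, window_size):
    # Pad both sequences with [PAD] tokens / zero segment ids so that each
    # concatenated window occupies exactly window_size positions.
    n_pad = window_size - len(out_tokens)
    assert n_pad >= 0
    return out_tokens + ["[PAD]"] * n_pad, segment_ids + [0] * n_pad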
Example #11
def encode(self, query_tokens, tokens) -> List[Tuple[List, List]]:
    # Reserve one extra slot for the position-marker token.
    content_len = self.max_seq_length - 3 - len(query_tokens) - 1
    insts = []
    passages = list(enum_passage(tokens, content_len))
    for idx, second_tokens in enumerate(passages):
        # Sample passage idx with probability g_factor ** idx.
        chance = math.pow(self.g_factor, idx)
        if random.random() < chance:
            # Mark whether this passage is the first, last, or a middle one.
            if idx == 0:
                mark = token_first
            elif idx == len(passages) - 1:
                mark = token_end
            else:
                mark = token_mid
            out_tokens = ["[CLS]"] + query_tokens + ["[SEP]", mark] + second_tokens + ["[SEP]"]
            segment_ids = [0] * (len(query_tokens) + 2) + [1] * (len(second_tokens) + 2)
            insts.append((out_tokens, segment_ids))
    return insts
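
Here each instance carries one extra marker token after the first [SEP], flagging the passage's position in the document; token_first, token_mid, and token_end are defined elsewhere in the original module. A self-contained illustration with placeholder marker values (assumed, not from the source):

token_first, token_mid, token_end = "[unused1]", "[unused2]", "[unused3]"

query_tokens = ["q"]
passages = [["a"] * 4, ["b"] * 4, ["c"] * 2]
for idx, second_tokens in enumerate(passages):
    if idx == 0:
        mark = token_first
    elif idx == len(passages) - 1:
        mark = token_end
    else:
        mark = token_mid
    out_tokens = ["[CLS]"] + query_tokens + ["[SEP]", mark] + second_tokens + ["[SEP]"]
    # Segment 1 spans the mark token, the passage, and the final [SEP],
    # which is why these variants use (len(second_tokens) + 2) for segment 1.
    segment_ids = [0] * (len(query_tokens) + 2) + [1] * (len(second_tokens) + 2)
    assert len(out_tokens) == len(segment_ids)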
Example #12
def get_candidate_all_passage_w_samping(
        max_seq_length=256, neg_k=1000) -> Dict[str, List[QCKCandidateWToken]]:
    qrel_path = os.path.join(data_path, "robust", "qrels.rob04.txt")
    galago_rank = load_bm25_best()
    tokens_d = load_robust_tokens_for_train()
    tokens_d.update(load_robust_tokens_for_predict(4))
    queries = load_robust04_title_query()
    tokenizer = get_tokenizer()
    judgement: Dict[str, Dict] = load_qrels_structured(qrel_path)
    out_d: Dict[str, List[QCKCandidateWToken]] = {}
    for query_id in queries:
        # Skip queries without relevance judgments.
        if query_id not in judgement:
            continue
        query = queries[query_id]
        query_tokens = tokenizer.tokenize(query)

        # Candidate documents: all judged docs plus the top neg_k BM25 docs.
        judge_entries = judgement[query_id]
        doc_ids = set(judge_entries.keys())
        ranked_list = galago_rank[query_id][:neg_k]
        doc_ids.update(e.doc_id for e in ranked_list)

        candidate = []
        for doc_id in doc_ids:
            tokens = tokens_d[doc_id]
            for idx, passage in enumerate(enum_passage(tokens, max_seq_length)):
                # Always keep the first passage; sample later ones at 10%.
                include = idx == 0 or random.random() < 0.1
                if include:
                    candidate.append(QCKCandidateWToken(doc_id, "", passage))

        out_d[query_id] = candidate
    return out_d
Example #13
def encode(self, query_tokens, tokens) -> List[Tuple[List, List]]:
    # Truncate overly long queries to 64 tokens.
    if len(query_tokens) > 64:
        query_tokens = query_tokens[:64]
    # Reserve one extra slot for the position-marker token.
    content_len = self.max_seq_length - 3 - len(query_tokens) - 1
    insts = []
    if not tokens:
        tokens = ['[PAD]']
    passages = list(enum_passage(tokens, content_len))
    for idx, second_tokens in enumerate(passages):
        # Mark whether this passage is the first, last, or a middle one.
        if idx == 0:
            mark = token_first
        elif idx == len(passages) - 1:
            mark = token_end
        else:
            mark = token_mid
        out_tokens = ["[CLS]"] + query_tokens + ["[SEP]", mark] + second_tokens + ["[SEP]"]
        segment_ids = [0] * (len(query_tokens) + 2) + [1] * (len(second_tokens) + 2)
        insts.append((out_tokens, segment_ids))
    return insts
Example #14
def make_candidate(e_id: int) -> Iterable[QCKCandidate]:
    # Tokenize one evidence text and yield a candidate per passage.
    text = evi_dict[e_id]
    tokens = tokenizer.tokenize(text)
    for passage in enum_passage(tokens, candidate_max_len):
        yield QCKCandidateWToken(str(e_id), "", passage)