Example #1
    def __init__(self, retrieved_path, embed_path):
        super().__init__()
        retrieved = [json.loads(l) for l in open(retrieved_path).readlines()]
        self.para_embed = np.load(embed_path)

        # map hashed question -> indices of its retrieved paragraphs in
        # para_embed and the corresponding relevance labels
        self.qid2para = {}
        for item in retrieved:
            self.qid2para[hash_question(item["question"])] = {
                "para_embed_idx": item["para_embed_idx"],
                "para_labels": item["para_labels"]
            }
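
`hash_question` is referenced throughout these snippets but never shown; a minimal sketch of such a helper, assuming it only needs to produce a stable key for a question string (the repository's actual implementation may differ):

import hashlib

def hash_question(question):
    # Hypothetical helper (not from the original code): stable key for a question.
    return hashlib.md5(question.strip().lower().encode("utf-8")).hexdigest()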
Example #2
    def __init__(self,
                 raw_data,
                 tokenizer,
                 max_query_length,
                 max_length,
                 db,
                 para_embed,
                 index2paraid='retrieval/index_data/idx_id.json',
                 matched_para_path="",
                 exact_search=False,
                 cased=False,
                 regex=False):

        self.max_length = max_length
        self.max_query_length = max_query_length
        self.para_embed = para_embed
        self.cased = cased  # SpanBERT uses cased tokenization
        self.regex = regex

        if self.cased:
            self.cased_tokenizer = BertTokenizer.from_pretrained(
                'bert-base-cased')

        if not exact_search:
            # approximate search: an IVF index over the 128-d paragraph
            # embeddings with 100 inverted lists, probing 20 lists per query
            quantizer = faiss.IndexFlatIP(128)
            self.index = faiss.IndexIVFFlat(quantizer, 128, 100)
            self.index.train(self.para_embed)
            self.index.add(self.para_embed)
            self.index.nprobe = 20
        else:
            # exact (brute-force) inner-product search
            self.index = faiss.IndexFlatIP(128)
            self.index.add(self.para_embed)

        self.tokenizer = tokenizer
        self.qa_data = [json.loads(l) for l in open(raw_data).readlines()]
        self.index2paraid = json.load(open(index2paraid))
        self.para_db = db
        self.matched_para_path = matched_para_path
        if self.matched_para_path != "":
            print(f"Load matched gold paras from {self.matched_para_path}")
            annotated = [
                json.loads(l)
                for l in tqdm(open(self.matched_para_path).readlines())
            ]
            self.qid2goldparas = {
                hash_question(item["question"]): item["matched_paras"]
                for item in annotated
            }

        self.basic_tokenizer = SimpleTokenizer()
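
The class that owns this constructor is not named in the snippet; a hedged usage sketch, with `OnlineQADataset` as a placeholder class name, placeholder hyperparameters and embedding path, and `doc_db` standing in for a document store that exposes `get_doc_text(idx)`:

import numpy as np
from transformers import BertTokenizer

# Placeholders: the real class name, embedding file, and db object are not
# shown in these snippets.
para_embed = np.load("retrieval/index_data/para_embed.npy").astype("float32")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

dataset = OnlineQADataset(
    raw_data="../data/wq-dev.txt",
    tokenizer=tokenizer,
    max_query_length=30,
    max_length=300,
    db=doc_db,
    para_embed=para_embed,
    matched_para_path="../data/wq_ft_dev_matched.txt",
)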
Example #3
def debug(retrieved="../data/wq_finetuneq_dev_5000.txt",
          raw_data="../data/wq-dev.txt",
          precomputed="../data/wq_ft_dev_matched.txt",
          k=10):
    # check whether it is reasonable to precompute a fixed paragraph set
    retrieved = [json.loads(l) for l in open(retrieved).readlines()]
    raw_data = [json.loads(l) for l in open(raw_data).readlines()]

    annotated = [json.loads(l) for l in open(precomputed).readlines()]
    qid2goldparas = {
        hash_question(item["question"]): item["matched_paras"]
        for item in annotated
    }

    topk_covered = []
    for qa, result in tqdm(zip(raw_data, retrieved), total=len(raw_data)):
        qid = hash_question(qa["question"])
        covered = 0
        for para_id in result["para_id"][:k]:
            if para_id in qid2goldparas[qid]:
                covered = 1
                break
        topk_covered.append(covered)
    print(np.mean(topk_covered))
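
The same check generalizes to several cutoffs in one pass; a small sketch (not part of the original code) that reuses `hash_question` and the structures loaded above:

def coverage_at_k(raw_data, retrieved, qid2goldparas, ks=(1, 5, 10, 20, 50)):
    # Fraction of questions whose top-k retrieved paragraphs contain at least
    # one matched gold paragraph, for each cutoff k.
    hits = {k: [] for k in ks}
    for qa, result in zip(raw_data, retrieved):
        gold = qid2goldparas.get(hash_question(qa["question"]), {})
        ranked = result["para_id"]
        for k in ks:
            hits[k].append(int(any(pid in gold for pid in ranked[:k])))
    return {k: float(np.mean(v)) for k, v in hits.items()}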
Example #4
    def eval_load(self, retriever, k=5):
        for qa in self.qa_data:
            with torch.no_grad():
                q_ids = torch.LongTensor(
                    self.tokenizer.encode(
                        qa["question"],
                        max_length=self.max_query_length)).view(1, -1).cuda()
                q_masks = torch.ones(q_ids.shape).bool().view(1, -1).cuda()
                q_cls = retriever.bert_q(q_ids, q_masks)[1]
                q_embed = retriever.proj_q(q_cls).data.cpu().numpy().astype(
                    'float32')
            _, I = self.index.search(q_embed, k)
            para_embed_idx = I.reshape(-1)
            para_idx = [self.index2paraid[str(_)] for _ in para_embed_idx]
            paras = [
                normalize(self.para_db.get_doc_text(idx)) for idx in para_idx
            ]
            para_embeds = self.para_embed[para_embed_idx]

            if self.cased:
                q_ids_cased = torch.LongTensor(
                    self.cased_tokenizer.encode(
                        qa["question"],
                        max_length=self.max_query_length)).view(1, -1)

            batched_examples = []
            # tokenize the top-k retrieved paragraphs (already normalized above)
            for p in paras:
                tokenizer = self.cased_tokenizer if self.cased else self.tokenizer
                doc_tokens, char_to_word_offset, orig_to_tok_index, tok_to_orig_index, all_doc_tokens = prepare(
                    p, tokenizer)

                batched_examples.append({
                    "qid": hash_question(qa["question"]),
                    "q": qa["question"],
                    "true_answers": qa["answer"],
                    "doc_toks": doc_tokens,
                    "doc_subtoks": all_doc_tokens,
                    "tok_to_orig_index": tok_to_orig_index,
                })

            for item in batched_examples:
                item["input_ids_q"] = q_ids.view(-1).cpu()

                if self.cased:
                    item["input_ids_q_cased"] = q_ids_cased.view(-1)
                    para_offset = item["input_ids_q_cased"].size(0)
                else:
                    para_offset = item["input_ids_q"].size(0)
                max_toks_for_doc = self.max_length - para_offset - 1
                para_subtoks = item["doc_subtoks"]
                if len(para_subtoks) > max_toks_for_doc:
                    para_subtoks = para_subtoks[:max_toks_for_doc]
                if self.cased:
                    p_ids = self.cased_tokenizer.convert_tokens_to_ids(
                        para_subtoks)
                else:
                    p_ids = self.tokenizer.convert_tokens_to_ids(para_subtoks)
                item["input_ids_c"] = self._add_special_token(
                    torch.LongTensor(p_ids))
                if self.cased:
                    item["input_ids"], item["segment_ids"] = self._join_sents(
                        item["input_ids_q_cased"][1:-1],
                        item["input_ids_c"][1:-1])
                else:
                    item["input_ids"], item["segment_ids"] = self._join_sents(
                        item["input_ids_q"][1:-1], item["input_ids_c"][1:-1])
                item["para_offset"] = para_offset
                item["paragraph_mask"] = torch.zeros(
                    item["input_ids"].shape).bool()
                item["paragraph_mask"][para_offset:-1] = 1

            yield self.collate(batched_examples, para_embeds)
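
`_add_special_token` and `_join_sents` are used above but not shown; a hedged sketch of what they might look like, assuming standard BERT-style packing with [CLS]/[SEP] and segment id 0 for the question, 1 for the paragraph (the repository's versions may differ):

    def _add_special_token(self, p_ids):
        # Hypothetical: wrap a paragraph in [CLS] ... [SEP].
        cls = torch.LongTensor([self.tokenizer.cls_token_id])
        sep = torch.LongTensor([self.tokenizer.sep_token_id])
        return torch.cat([cls, p_ids, sep])

    def _join_sents(self, q_ids, p_ids):
        # Hypothetical: [CLS] question [SEP] paragraph [SEP], with BERT segment ids.
        cls = torch.LongTensor([self.tokenizer.cls_token_id])
        sep = torch.LongTensor([self.tokenizer.sep_token_id])
        input_ids = torch.cat([cls, q_ids, sep, p_ids, sep])
        segment_ids = torch.cat([
            torch.zeros(q_ids.size(0) + 2, dtype=torch.long),  # [CLS] q [SEP]
            torch.ones(p_ids.size(0) + 1, dtype=torch.long),   # p [SEP]
        ])
        return input_ids, segment_ids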
Example #5
    def load(self, retriever, k=5):
        for qa in self.qa_data:
            with torch.no_grad():
                q_ids = torch.LongTensor(
                    self.tokenizer.encode(
                        qa["question"],
                        max_length=self.max_query_length)).view(1, -1).cuda()
                q_masks = torch.ones(q_ids.shape).bool().view(1, -1).cuda()
                q_cls = retriever.bert_q(q_ids, q_masks)[1]
                q_embed = retriever.proj_q(q_cls).data.cpu().numpy().astype(
                    'float32')

            _, I = self.index.search(q_embed, 5000)  # retrieve
            para_embed_idx = I.reshape(-1)

            if self.cased:
                q_ids_cased = torch.LongTensor(
                    self.cased_tokenizer.encode(
                        qa["question"],
                        max_length=self.max_query_length)).view(1, -1)

            para_idx = [self.index2paraid[str(_)] for _ in para_embed_idx]
            para_embeds = self.para_embed[para_embed_idx]

            qid = hash_question(qa["question"])
            gold_paras = self.qid2goldparas[qid]

            # paragraph-level labels: whether each of the 5000 retrieved
            # paragraphs is one of the pre-matched gold paragraphs
            p_labels = []
            batched_examples = []
            topk5000_labels = [int(_ in gold_paras) for _ in para_idx]

            # match answer spans in the top-k retrieved paragraphs
            for p_idx in para_idx[:k]:
                p = normalize(self.para_db.get_doc_text(p_idx))
                # p_covered, matched_string = para_has_answer(p, qa["answer"], self.basic_tokenizer)
                matched_spans = match_answer_span(
                    p,
                    qa["answer"],
                    self.basic_tokenizer,
                    match="regex" if self.regex else "string")
                p_covered = int(len(matched_spans) > 0)
                ans_starts, ans_ends, ans_texts = [], [], []

                if self.cased:
                    doc_tokens, char_to_word_offset, orig_to_tok_index, tok_to_orig_index, all_doc_tokens = prepare(
                        p, self.cased_tokenizer)
                else:
                    doc_tokens, char_to_word_offset, orig_to_tok_index, tok_to_orig_index, all_doc_tokens = prepare(
                        p, self.tokenizer)

                if p_covered:
                    for matched_string in matched_spans:
                        char_starts = [
                            i for i in range(len(p))
                            if p.startswith(matched_string, i)
                        ]
                        if len(char_starts) > 0:
                            char_ends = [
                                start + len(matched_string) - 1
                                for start in char_starts
                            ]
                            answer = {
                                "text": matched_string,
                                "char_spans": list(zip(char_starts, char_ends))
                            }

                            if self.cased:
                                ans_spans = find_ans_span_with_char_offsets(
                                    answer, char_to_word_offset, doc_tokens,
                                    all_doc_tokens, orig_to_tok_index,
                                    self.cased_tokenizer)
                            else:
                                ans_spans = find_ans_span_with_char_offsets(
                                    answer, char_to_word_offset, doc_tokens,
                                    all_doc_tokens, orig_to_tok_index,
                                    self.tokenizer)

                            for s, e in ans_spans:
                                ans_starts.append(s)
                                ans_ends.append(e)
                                ans_texts.append(matched_string)
                batched_examples.append({
                    "qid": hash_question(qa["question"]),
                    "q": qa["question"],
                    "true_answers": qa["answer"],
                    "doc_subtoks": all_doc_tokens,
                    "starts": ans_starts,
                    "ends": ans_ends,
                    "covered": p_covered
                })

                p_labels.append(int(p_covered))

            # compute the loss only when the top-5000 retrieval or the top-k
            # span matching covers an answer paragraph
            if np.sum(topk5000_labels) > 0 or np.sum(p_labels) > 0:
                # training tensors
                for item in batched_examples:
                    item["input_ids_q"] = q_ids.view(-1).cpu()

                    if self.cased:
                        item["input_ids_q_cased"] = q_ids_cased.view(-1)
                        para_offset = item["input_ids_q_cased"].size(0)
                    else:
                        para_offset = item["input_ids_q"].size(0)

                    max_toks_for_doc = self.max_length - para_offset - 1
                    para_subtoks = item["doc_subtoks"]
                    if len(para_subtoks) > max_toks_for_doc:
                        para_subtoks = para_subtoks[:max_toks_for_doc]

                    if self.cased:
                        p_ids = self.cased_tokenizer.convert_tokens_to_ids(
                            para_subtoks)
                    else:
                        p_ids = self.tokenizer.convert_tokens_to_ids(
                            para_subtoks)
                    item["input_ids_c"] = self._add_special_token(
                        torch.LongTensor(p_ids))
                    paragraph = item["input_ids_c"][1:-1]
                    if self.cased:
                        item["input_ids"], item[
                            "segment_ids"] = self._join_sents(
                                item["input_ids_q_cased"][1:-1],
                                item["input_ids_c"][1:-1])
                    else:
                        item["input_ids"], item[
                            "segment_ids"] = self._join_sents(
                                item["input_ids_q"][1:-1],
                                item["input_ids_c"][1:-1])
                    item["para_offset"] = para_offset
                    item["paragraph_mask"] = torch.zeros(
                        item["input_ids"].shape).bool()
                    item["paragraph_mask"][para_offset:-1] = 1

                    starts, ends = item["starts"], item["ends"]
                    start_positions, end_positions = [], []

                    covered = item["covered"]
                    if covered:
                        covered = 0
                        for s, e in zip(starts, ends):
                            assert s <= e
                            if s >= paragraph.size(0):
                                continue
                            else:
                                start_position = min(
                                    s,
                                    paragraph.size(0) - 1) + para_offset
                                end_position = min(
                                    e,
                                    paragraph.size(0) - 1) + para_offset
                                covered = 1
                                start_positions.append(start_position)
                                end_positions.append(end_position)
                    if len(start_positions) == 0:
                        assert not covered
                        start_positions.append(-1)
                        end_positions.append(-1)

                    start_tensor, end_tensor, covered = torch.LongTensor(
                        start_positions), torch.LongTensor(
                            end_positions), torch.LongTensor([covered])

                    item["start"] = start_tensor
                    item["end"] = end_tensor
                    item["covered"] = covered

                yield self.collate(batched_examples, para_embeds,
                                   topk5000_labels)
            else:
                yield {}
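
`self.collate` is also referenced but not shown; a hedged sketch of a collator over `batched_examples` (pad the joined sequences to the batch maximum, stack masks, keep the retrieved paragraph embeddings); the variable-length span targets built in `load` ("start", "end", "covered") would need similar handling, and the real implementation may differ:

import torch
from torch.nn.utils.rnn import pad_sequence

    def collate(self, batched_examples, para_embeds, para_labels=None):
        # Hypothetical collator: pad per-paragraph tensors and stack them.
        batch = {
            "input_ids": pad_sequence(
                [ex["input_ids"] for ex in batched_examples], batch_first=True),
            "segment_ids": pad_sequence(
                [ex["segment_ids"] for ex in batched_examples], batch_first=True),
            "paragraph_mask": pad_sequence(
                [ex["paragraph_mask"] for ex in batched_examples], batch_first=True),
            "para_offset": torch.LongTensor(
                [ex["para_offset"] for ex in batched_examples]),
            "para_embeds": torch.from_numpy(para_embeds),
        }
        batch["input_mask"] = batch["input_ids"] != 0  # 0 is the BERT [PAD] id
        if para_labels is not None:
            batch["para_labels"] = torch.LongTensor(para_labels)
        return batch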