def __init__(self, retrieved_path, embed_path):
    super().__init__()
    retrieved = [json.loads(l) for l in open(retrieved_path).readlines()]
    self.para_embed = np.load(embed_path)
    self.qid2para = {}
    for item in retrieved:
        self.qid2para[hash_question(item["question"])] = {
            "para_embed_idx": item["para_embed_idx"],
            "para_labels": item["para_labels"]
        }
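
# Illustrative only: one record of the JSONL schema the constructor above
# expects in `retrieved_path`. The keys are taken from the code; the values
# here are made up for illustration.
_EXAMPLE_RETRIEVED_LINE = {
    "question": "who sang i ran all the way home",  # hashed into a qid
    "para_embed_idx": [11, 4508, 92],               # rows of the para_embed matrix
    "para_labels": [0, 1, 0],                       # 1 = gold (answer-bearing) paragraph
}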
def __init__(self,
             raw_data,
             tokenizer,
             max_query_length,
             max_length,
             db,
             para_embed,
             index2paraid='retrieval/index_data/idx_id.json',
             matched_para_path="",
             exact_search=False,
             cased=False,
             regex=False):
    self.max_length = max_length
    self.max_query_length = max_query_length
    self.para_embed = para_embed
    self.cased = cased  # SpanBERT uses cased tokenization
    self.regex = regex
    if self.cased:
        self.cased_tokenizer = BertTokenizer.from_pretrained(
            'bert-base-cased')

    # Approximate maximum-inner-product search over the 128-d paragraph
    # embeddings. IndexIVFFlat defaults to the L2 metric, so inner product
    # is requested explicitly to match the IndexFlatIP coarse quantizer.
    quantizer = faiss.IndexFlatIP(128)
    self.index = faiss.IndexIVFFlat(quantizer, 128, 100,
                                    faiss.METRIC_INNER_PRODUCT)
    self.index.train(self.para_embed)
    self.index.add(self.para_embed)
    self.index.nprobe = 20
    # Exact-search alternative (currently unused):
    # if exact_search:
    #     self.index = faiss.IndexFlatIP(128)
    #     self.index.add(self.para_embed)

    self.tokenizer = tokenizer
    self.qa_data = [json.loads(l) for l in open(raw_data).readlines()]
    self.index2paraid = json.load(open(index2paraid))
    self.para_db = db
    self.matched_para_path = matched_para_path
    if self.matched_para_path != "":
        print(f"Load matched gold paras from {self.matched_para_path}")
        annotated = [
            json.loads(l)
            for l in tqdm(open(self.matched_para_path).readlines())
        ]
        self.qid2goldparas = {
            hash_question(item["question"]): item["matched_paras"]
            for item in annotated
        }
    self.basic_tokenizer = SimpleTokenizer()
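
def _demo_ivf_index():
    # A minimal, self-contained sketch of the FAISS setup used in __init__
    # above, run on random data (the real index is trained on para_embed).
    import numpy as np
    import faiss
    dim, nlist = 128, 100
    xb = np.random.rand(10000, dim).astype('float32')
    quantizer = faiss.IndexFlatIP(dim)
    index = faiss.IndexIVFFlat(quantizer, dim, nlist,
                               faiss.METRIC_INNER_PRODUCT)
    index.train(xb)    # learn the nlist coarse centroids
    index.add(xb)      # add all paragraph vectors
    index.nprobe = 20  # probe 20 of the 100 clusters per query
    distances, ids = index.search(xb[:1], 5)  # top-5 neighbours of one query
    return ids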
def debug(retrieved="../data/wq_finetuneq_dev_5000.txt",
          raw_data="../data/wq-dev.txt",
          precomputed="../data/wq_ft_dev_matched.txt",
          k=10):
    # Check whether it is reasonable to precompute a paragraph set: report
    # the fraction of questions whose top-k retrieved paragraphs include a
    # precomputed gold (answer-matched) paragraph.
    retrieved = [json.loads(l) for l in open(retrieved).readlines()]
    raw_data = [json.loads(l) for l in open(raw_data).readlines()]
    annotated = [json.loads(l) for l in open(precomputed).readlines()]
    qid2goldparas = {
        hash_question(item["question"]): item["matched_paras"]
        for item in annotated
    }
    topk_covered = []
    for qa, result in tqdm(zip(raw_data, retrieved), total=len(raw_data)):
        qid = hash_question(qa["question"])
        covered = 0
        for para_id in result["para_id"][:k]:
            if para_id in qid2goldparas[qid]:
                covered = 1
                break
        topk_covered.append(covered)
    print(np.mean(topk_covered))
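
def _demo_topk_coverage():
    # Toy illustration (made-up paragraph ids) of the metric debug() prints:
    # a question counts as covered if any of its top-k retrieved paragraph
    # ids appears in its gold set.
    retrieved_ids = [["p1", "p7", "p3"], ["p9", "p2", "p4"]]
    gold_sets = [{"p3"}, {"p8"}]
    k = 3
    covered = [
        int(any(pid in gold for pid in ids[:k]))
        for ids, gold in zip(retrieved_ids, gold_sets)
    ]
    return sum(covered) / len(covered)  # -> 0.5 on this toy data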
def eval_load(self, retriever, k=5):
    for qa in self.qa_data:
        # Encode the question with the trained retriever.
        with torch.no_grad():
            q_ids = torch.LongTensor(
                self.tokenizer.encode(
                    qa["question"],
                    max_length=self.max_query_length)).view(1, -1).cuda()
            q_masks = torch.ones(q_ids.shape).bool().view(1, -1).cuda()
            q_cls = retriever.bert_q(q_ids, q_masks)[1]
            q_embed = retriever.proj_q(q_cls).data.cpu().numpy().astype(
                'float32')
        # Retrieve the top-k paragraphs from the FAISS index.
        _, I = self.index.search(q_embed, k)
        para_embed_idx = I.reshape(-1)
        para_idx = [self.index2paraid[str(_)] for _ in para_embed_idx]
        paras = [
            normalize(self.para_db.get_doc_text(idx)) for idx in para_idx
        ]
        para_embeds = self.para_embed[para_embed_idx]
        if self.cased:
            q_ids_cased = torch.LongTensor(
                self.cased_tokenizer.encode(
                    qa["question"],
                    max_length=self.max_query_length)).view(1, -1)

        # Tokenize the top-k paragraphs into reader examples.
        batched_examples = []
        for p in paras:
            tokenizer = self.cased_tokenizer if self.cased else self.tokenizer
            doc_tokens, char_to_word_offset, orig_to_tok_index, \
                tok_to_orig_index, all_doc_tokens = prepare(p, tokenizer)
            batched_examples.append({
                "qid": hash_question(qa["question"]),
                "q": qa["question"],
                "true_answers": qa["answer"],
                "doc_toks": doc_tokens,
                "doc_subtoks": all_doc_tokens,
                "tok_to_orig_index": tok_to_orig_index,
            })

        # Build model input tensors: [CLS] question [SEP] paragraph [SEP].
        for item in batched_examples:
            item["input_ids_q"] = q_ids.view(-1).cpu()
            if self.cased:
                item["input_ids_q_cased"] = q_ids_cased.view(-1)
                para_offset = item["input_ids_q_cased"].size(0)
            else:
                para_offset = item["input_ids_q"].size(0)
            # Truncate the paragraph so question + paragraph + [SEP] fit.
            max_toks_for_doc = self.max_length - para_offset - 1
            para_subtoks = item["doc_subtoks"]
            if len(para_subtoks) > max_toks_for_doc:
                para_subtoks = para_subtoks[:max_toks_for_doc]
            if self.cased:
                p_ids = self.cased_tokenizer.convert_tokens_to_ids(
                    para_subtoks)
            else:
                p_ids = self.tokenizer.convert_tokens_to_ids(para_subtoks)
            item["input_ids_c"] = self._add_special_token(
                torch.LongTensor(p_ids))
            if self.cased:
                item["input_ids"], item["segment_ids"] = self._join_sents(
                    item["input_ids_q_cased"][1:-1],
                    item["input_ids_c"][1:-1])
            else:
                item["input_ids"], item["segment_ids"] = self._join_sents(
                    item["input_ids_q"][1:-1], item["input_ids_c"][1:-1])
            item["para_offset"] = para_offset
            # Mask marking the paragraph tokens, where spans may be predicted.
            item["paragraph_mask"] = torch.zeros(
                item["input_ids"].shape).bool()
            item["paragraph_mask"][para_offset:-1] = 1
        yield self.collate(batched_examples, para_embeds)
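
def _demo_reader_inputs():
    # Toy sketch (made-up token ids) of the reader input layout built above,
    # assuming _add_special_token produces [CLS] tokens [SEP] and _join_sents
    # concatenates the two stripped sequences BERT-style; both helpers are
    # defined elsewhere in this class, so this is an assumption, not their
    # actual implementation. 101/102 are BERT's [CLS]/[SEP] ids.
    import torch
    q = torch.LongTensor([101, 2040, 2001, 102])  # [CLS] who was [SEP]
    c = torch.LongTensor([101, 3419, 4068, 102])  # [CLS] obama born [SEP]
    # Assumed _join_sents(q[1:-1], c[1:-1]) output:
    input_ids = torch.cat([q[:1], q[1:-1], q[-1:], c[1:-1], c[-1:]])
    segment_ids = torch.cat(
        [torch.zeros(q.size(0)), torch.ones(c.size(0) - 1)]).long()
    para_offset = q.size(0)  # index of the first paragraph token
    paragraph_mask = torch.zeros(input_ids.shape).bool()
    paragraph_mask[para_offset:-1] = True  # True on paragraph tokens only
    return input_ids, segment_ids, paragraph_mask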
def load(self, retriever, k=5):
    for qa in self.qa_data:
        # Encode the question with the retriever.
        with torch.no_grad():
            q_ids = torch.LongTensor(
                self.tokenizer.encode(
                    qa["question"],
                    max_length=self.max_query_length)).view(1, -1).cuda()
            q_masks = torch.ones(q_ids.shape).bool().view(1, -1).cuda()
            q_cls = retriever.bert_q(q_ids, q_masks)[1]
            q_embed = retriever.proj_q(q_cls).data.cpu().numpy().astype(
                'float32')
        # Retrieve a deep candidate pool for the retrieval labels.
        _, I = self.index.search(q_embed, 5000)
        para_embed_idx = I.reshape(-1)
        if self.cased:
            q_ids_cased = torch.LongTensor(
                self.cased_tokenizer.encode(
                    qa["question"],
                    max_length=self.max_query_length)).view(1, -1)
        para_idx = [self.index2paraid[str(_)] for _ in para_embed_idx]
        para_embeds = self.para_embed[para_embed_idx]
        qid = hash_question(qa["question"])
        gold_paras = self.qid2goldparas[qid]

        p_labels = []
        batched_examples = []
        # 1 if a retrieved paragraph is in the precomputed gold set.
        topk5000_labels = [int(_ in gold_paras) for _ in para_idx]

        # Match answer spans in the top-k paragraphs to build reader labels.
        for p_idx in para_idx[:k]:
            p = normalize(self.para_db.get_doc_text(p_idx))
            matched_spans = match_answer_span(
                p,
                qa["answer"],
                self.basic_tokenizer,
                match="regex" if self.regex else "string")
            p_covered = int(len(matched_spans) > 0)
            ans_starts, ans_ends, ans_texts = [], [], []
            tokenizer = self.cased_tokenizer if self.cased else self.tokenizer
            doc_tokens, char_to_word_offset, orig_to_tok_index, \
                tok_to_orig_index, all_doc_tokens = prepare(p, tokenizer)
            if p_covered:
                for matched_string in matched_spans:
                    # Every character offset where the matched string occurs.
                    char_starts = [
                        i for i in range(len(p))
                        if p.startswith(matched_string, i)
                    ]
                    if len(char_starts) > 0:
                        char_ends = [
                            start + len(matched_string) - 1
                            for start in char_starts
                        ]
                        answer = {
                            "text": matched_string,
                            "char_spans": list(zip(char_starts, char_ends))
                        }
                        ans_spans = find_ans_span_with_char_offsets(
                            answer, char_to_word_offset, doc_tokens,
                            all_doc_tokens, orig_to_tok_index, tokenizer)
                        for s, e in ans_spans:
                            ans_starts.append(s)
                            ans_ends.append(e)
                            ans_texts.append(matched_string)
            batched_examples.append({
                "qid": qid,
                "q": qa["question"],
                "true_answers": qa["answer"],
                "doc_subtoks": all_doc_tokens,
                "starts": ans_starts,
                "ends": ans_ends,
                "covered": p_covered
            })
            p_labels.append(int(p_covered))

        # Calculate the loss only when the top-5000 pool (or the top-k set)
        # covers an answer paragraph; otherwise yield an empty batch.
        if np.sum(topk5000_labels) > 0 or np.sum(p_labels) > 0:
            # Build training tensors.
            for item in batched_examples:
                item["input_ids_q"] = q_ids.view(-1).cpu()
                if self.cased:
                    item["input_ids_q_cased"] = q_ids_cased.view(-1)
                    para_offset = item["input_ids_q_cased"].size(0)
                else:
                    para_offset = item["input_ids_q"].size(0)
                max_toks_for_doc = self.max_length - para_offset - 1
                para_subtoks = item["doc_subtoks"]
                if len(para_subtoks) > max_toks_for_doc:
                    para_subtoks = para_subtoks[:max_toks_for_doc]
                if self.cased:
                    p_ids = self.cased_tokenizer.convert_tokens_to_ids(
                        para_subtoks)
                else:
                    p_ids = self.tokenizer.convert_tokens_to_ids(
                        para_subtoks)
                item["input_ids_c"] = self._add_special_token(
                    torch.LongTensor(p_ids))
                paragraph = item["input_ids_c"][1:-1]
                if self.cased:
                    item["input_ids"], item["segment_ids"] = self._join_sents(
                        item["input_ids_q_cased"][1:-1],
                        item["input_ids_c"][1:-1])
                else:
                    item["input_ids"], item["segment_ids"] = self._join_sents(
                        item["input_ids_q"][1:-1], item["input_ids_c"][1:-1])
                item["para_offset"] = para_offset
                item["paragraph_mask"] = torch.zeros(
                    item["input_ids"].shape).bool()
                item["paragraph_mask"][para_offset:-1] = 1

                # Shift answer spans by the question offset, dropping spans
                # that were truncated away with the paragraph tail.
                starts, ends = item["starts"], item["ends"]
                start_positions, end_positions = [], []
                covered = item["covered"]
                if covered:
                    covered = 0
                    for s, e in zip(starts, ends):
                        assert s <= e
                        if s >= paragraph.size(0):
                            continue
                        start_positions.append(
                            min(s, paragraph.size(0) - 1) + para_offset)
                        end_positions.append(
                            min(e, paragraph.size(0) - 1) + para_offset)
                        covered = 1
                if len(start_positions) == 0:
                    assert not covered
                    start_positions.append(-1)
                    end_positions.append(-1)
                item["start"] = torch.LongTensor(start_positions)
                item["end"] = torch.LongTensor(end_positions)
                item["covered"] = torch.LongTensor([covered])
            yield self.collate(batched_examples, para_embeds,
                               topk5000_labels)
        else:
            yield {}
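
def _demo_consume_load(dataset, retriever):
    # Hypothetical consumer of load() above (the names here are illustrative,
    # not from this repo): the generator yields an empty dict when neither
    # the top-5000 pool nor the top-k paragraphs contain the answer, so a
    # training loop must skip those questions.
    n_batches, n_skipped = 0, 0
    for batch in dataset.load(retriever, k=5):
        if not batch:
            n_skipped += 1  # retrieval miss: no loss signal for this question
            continue
        n_batches += 1
        # ... forward/backward over batch["input_ids"], batch["start"], etc.
    return n_batches, n_skipped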