def sentence_payload_gen(q_res_path: str, top_n, data_id_man: DataIDManager):
    print("loading ranked list")
    ranked_list: Dict[str, List[SimpleRankedListEntry]] = load_galago_ranked_list(q_res_path)

    # Keep only the first 10 queries.
    qid_list = list(ranked_list.keys())
    qid_list = qid_list[:10]
    ranked_list = {k: ranked_list[k] for k in qid_list}

    print("Pre loading docs")
    preload_docs(ranked_list, top_n)
    entries: List[Tuple[str, bool, int]] = []

    def enum_sentence(tokens) -> Iterator[str]:
        # Re-join the whitespace tokens and split the text back into sentences.
        text = " ".join(tokens)
        sents = sent_tokenize(text)
        yield from sents

    ticker = TimeEstimator(len(ranked_list))
    for qid in ranked_list:
        q_res: List[SimpleRankedListEntry] = ranked_list[qid]
        docs = iterate_docs(q_res, top_n)
        for doc in docs:
            for sent_idx, sent in enumerate(enum_sentence(doc.tokens)):
                info = {
                    'doc_id': doc.doc_id,
                    'sent_idx': sent_idx,
                    'sentence': sent
                }
                data_id = data_id_man.assign(info)
                e = sent, True, data_id
                entries.append(e)
        ticker.tick()
    return entries
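# Illustrative sketch only (not part of the pipeline above): enum_sentence joins
# whitespace tokens back into text and relies on NLTK's sent_tokenize to recover
# sentence boundaries. The tokens below are made-up sample data.
def _demo_enum_sentence():
    from nltk.tokenize import sent_tokenize
    tokens = ["The", "cat", "sat", ".", "It", "slept", "."]
    text = " ".join(tokens)
    return sent_tokenize(text)  # typically ['The cat sat .', 'It slept .']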
def collect_info_transform(data: Iterable[Tuple[QCKQuery, QCKCandidate, bool]],
                           data_id_man: DataIDManager) -> Iterable[QCInstance]:
    for query, candidate, is_correct in data:
        info = {
            'query': get_light_qckquery(query),
            'candidate': get_light_qckcandidate(candidate)
        }
        yield QCInstance(query.text, candidate.text, data_id_man.assign(info), int(is_correct))
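# Illustrative usage sketch (assumed driver, not from the original code): the
# transform is a generator, so data_id assignments happen only as it is consumed;
# materializing it populates data_id_man.id_to_info with the light query/candidate
# info keyed by each instance's data_id.
def _demo_collect_info_transform(triples, data_id_man: DataIDManager):
    instances = list(collect_info_transform(triples, data_id_man))
    return instances, data_id_man.id_to_info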
def generate(self, data_id_manager: DataIDManager, query_list) -> List[ClassificationInstanceWDataID]:
    neg_k = 1000
    all_insts = []
    for query_id in query_list:
        if query_id not in self.judgement:
            continue

        judgement = self.judgement[query_id]
        query = self.queries[query_id]
        query_tokens = self.tokenizer.tokenize(query)

        ranked_list = self.galago_rank[query_id]
        ranked_list = ranked_list[:neg_k]

        # Candidate documents: all judged docs plus the top-k retrieved docs.
        target_docs = set(judgement.keys())
        target_docs.update([e.doc_id for e in ranked_list])
        print("Total of {} docs".format(len(target_docs)))

        for doc_id in target_docs:
            tokens = self.data[doc_id]
            insts: List[Tuple[List, List]] = self.encoder.encode(query_tokens, tokens)
            label = 1 if doc_id in judgement and judgement[doc_id] > 0 else 0
            if label:
                # Positive docs: select target passages by their precomputed scores.
                passage_scores = list([self.scores[query_id, doc_id, idx]
                                       for idx, _ in enumerate(insts)])
                target_indices = self.get_target_indices(passage_scores)
            else:
                # Negative docs: keep only the first passage.
                target_indices = [0]

            # With 10% probability, add one extra random passage.
            n = len(insts)
            if random.random() < 0.1 and n > 1:
                idx = random.randint(1, n - 1)
                target_indices.append(idx)

            for passage_idx in target_indices:
                tokens_seg, seg_ids = insts[passage_idx]
                assert type(tokens_seg[0]) == str
                assert type(seg_ids[0]) == int
                data_id = data_id_manager.assign({
                    'doc_id': doc_id,
                    'passage_idx': passage_idx,
                    'label': label,
                    'tokens': tokens_seg,
                    'seg_ids': seg_ids,
                })
                all_insts.append(
                    ClassificationInstanceWDataID(tokens_seg, seg_ids, label, data_id))
    return all_insts
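# Sketch of the passage-selection policy used in generate() above, written standalone
# for clarity (not part of the original code). get_target_indices is assumed to pick
# the indices of the best-scoring passages; here it is mocked by a single argmax.
def _demo_select_passage_indices(passage_scores, label):
    import random
    n = len(passage_scores)
    if label:
        target_indices = [max(range(n), key=lambda i: passage_scores[i])]
    else:
        target_indices = [0]  # negatives: first passage only
    if random.random() < 0.1 and n > 1:
        target_indices.append(random.randint(1, n - 1))  # occasional extra passage
    return target_indices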
def main():
    data_id_manager = DataIDManager()
    data = []
    for text in enum_f5_data():
        info = {
            'text': text,
        }
        data_id = data_id_manager.assign(info)
        label = 0
        data.append(TextInstance(text, label, data_id))

    encode_fn = get_encode_fn_w_data_id(512, False)
    save_path = at_output_dir("clue_counter_arg", "clue_f5.tfrecord")
    write_records_w_encode_fn(save_path, encode_fn, data)

    info_save_path = at_output_dir("clue_counter_arg", "clue_f5.tfrecord.info")
    json.dump(data_id_manager.id_to_info, open(info_save_path, "w"))
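# Illustrative follow-up (not in the original code): the .info file written above maps
# each data_id to its source text, so predictions on clue_f5.tfrecord can be traced
# back later. Note that json.dump turns the integer data_ids into string keys.
def _demo_lookup_text(data_id: int) -> str:
    info_path = at_output_dir("clue_counter_arg", "clue_f5.tfrecord.info")
    id_to_info = json.load(open(info_path))
    return id_to_info[str(data_id)]['text']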
def work(self, job_id):
    qids = self.query_group[job_id]

    # Each job owns a disjoint data_id range of one million ids.
    max_data_per_job = 1000 * 1000
    base = job_id * max_data_per_job
    data_id_manager = DataIDManager(base, base + max_data_per_job)

    output_path = os.path.join(self.out_dir, str(job_id))
    writer = RecordWriterWrap(output_path)
    for qid in qids:
        try:
            sr_per_qid = self.seg_resource_loader.load_for_qid(qid)
            docs_to_predict = select_one_pos_neg_doc(sr_per_qid.sr_per_query_doc)
            for sr_per_doc in docs_to_predict:
                label_id = sr_per_doc.label
                if self.skip_single_seg and len(sr_per_doc.segs) == 1:
                    continue
                for seg_idx, seg in enumerate(sr_per_doc.segs):
                    info = {
                        'qid': qid,
                        'doc_id': sr_per_doc.doc_id,
                        'seg_idx': seg_idx
                    }
                    data_id = data_id_manager.assign(info)
                    feature = encode_sr(seg, self.max_seq_length, label_id, data_id)
                    writer.write_feature(feature)
        except FileNotFoundError:
            # Missing segment resources are tolerated only for known-missing qids.
            if qid in missing_qids:
                pass
            else:
                raise
    writer.close()

    info_save_path = os.path.join(self.info_dir, "{}.info".format(job_id))
    json.dump(data_id_manager.id_to_info, open(info_save_path, "w"))
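# Illustrative sketch (not part of the worker above): because each job owns the
# data_id range [job_id * 1,000,000, (job_id + 1) * 1,000,000), a data_id can be
# mapped back to its job and then to (qid, doc_id, seg_idx) via the per-job info file.
def _demo_resolve_data_id(info_dir: str, data_id: int):
    max_data_per_job = 1000 * 1000
    job_id = data_id // max_data_per_job
    info_path = os.path.join(info_dir, "{}.info".format(job_id))
    id_to_info = json.load(open(info_path))  # json keys are strings
    return id_to_info[str(data_id)]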
def generate(self, triplet_itr: Iterator[Tuple[Query, Doc, Doc]]):
    data_id_manager = DataIDManager()
    for q, d1, d2 in triplet_itr:
        data_id = data_id_manager.assign({})
        inst = self.triplet_to_paired_instance(q, d1, d2, data_id)
        yield inst
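# Illustrative usage sketch (assumed driver, not from the original code): generate()
# is lazy, so paired instances only materialize as triplet_itr is consumed. Each
# data_id here is assigned an empty info dict, so unlike the generators above there
# is no id_to_info payload worth dumping afterwards.
def _demo_consume_triplets(pair_generator, triplet_itr):
    return list(pair_generator.generate(triplet_itr))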