def __init__(self, max_seq_length, out_dir):
    """Set up train-split resources and the output/info directories.

    Args:
        max_seq_length: maximum token length used when encoding segments.
        out_dir: directory where per-job record files are written; a sibling
            "<out_dir>_info" directory is created alongside it.
    """
    self.max_seq_length = max_seq_length
    self.out_dir = out_dir
    # Per-job query groups; work(job_id) processes one group.
    self.query_group: List[List[QueryID]] = load_query_group("train")
    self.seg_resource_loader = SegmentResourceLoader(job_man_dir, "train")
    self.info_dir = self.out_dir + "_info"
    exist_or_mkdir(self.info_dir)
def __init__(self, max_seq_length, split, skip_single_seg, pick_for_pairwise, out_dir):
    """Set up resources for the given split and create the info directory.

    Args:
        max_seq_length: maximum token length used when encoding segments.
        split: dataset split name passed to the query/segment loaders.
        skip_single_seg: if true, documents with a single segment are skipped.
        pick_for_pairwise: stored flag controlling pairwise doc selection.
        out_dir: output directory; "<out_dir>_info" is created next to it.
    """
    self.max_seq_length = max_seq_length
    self.out_dir = out_dir
    self.skip_single_seg = skip_single_seg
    self.pick_for_pairwise = pick_for_pairwise
    # Per-job query groups; work(job_id) processes one group.
    self.query_group: List[List[QueryID]] = load_query_group(split)
    self.seg_resource_loader = SegmentResourceLoader(job_man_dir, split)
    self.info_dir = self.out_dir + "_info"
    exist_or_mkdir(self.info_dir)
def seg_resource_loader_test():
    """Smoke test: load one train query and print its first doc's segments."""
    loader = SegmentResourceLoader(job_man_dir, "train")
    target_qid = "1000008"
    sr_per_query: SRPerQuery = loader.load_for_qid(target_qid)
    assert sr_per_query.qid == target_qid
    for sr in sr_per_query.sr_per_query_doc:
        print(sr.doc_id)
        print(len(sr.segs))
        for seg in sr.segs:
            print(seg.first_seg)
            print(seg.second_seg)
            # Only inspect the first segment.
            break
        # Only inspect the first document.
        break
class SingleSegTrainGen:
    """Generates pointwise training records from documents that are
    represented by exactly one segment (multi-segment docs are skipped)."""

    def __init__(self, max_seq_length, out_dir):
        """Load train-split resources and prepare output directories.

        Args:
            max_seq_length: maximum token length used when encoding segments.
            out_dir: directory for per-job record files; "<out_dir>_info"
                is created alongside it.
        """
        self.query_group: List[List[QueryID]] = load_query_group("train")
        self.seg_resource_loader = SegmentResourceLoader(job_man_dir, "train")
        self.max_seq_length = max_seq_length
        self.out_dir = out_dir
        self.info_dir = self.out_dir + "_info"
        exist_or_mkdir(self.info_dir)

    def work(self, job_id):
        """Encode single-segment docs for this job's queries into one record file.

        Queries listed in missing_qids may legitimately have no resource file
        and are skipped; any other FileNotFoundError is re-raised.
        """
        qids = self.query_group[job_id]
        output_path = os.path.join(self.out_dir, str(job_id))
        writer = RecordWriterWrap(output_path)
        for qid in qids:
            # Fix: keep the try minimal so a FileNotFoundError raised by
            # encode_sr/write_feature is never mistaken for a missing qid file.
            try:
                sr_per_qid = self.seg_resource_loader.load_for_qid(qid)
            except FileNotFoundError:
                if qid in missing_qids:
                    continue
                raise
            for sr_per_doc in sr_per_qid.sr_per_query_doc:
                # This generator only uses docs with exactly one segment.
                if len(sr_per_doc.segs) > 1:
                    continue
                label_id = sr_per_doc.label
                seg = sr_per_doc.segs[0]
                feature = encode_sr(seg, self.max_seq_length, label_id)
                writer.write_feature(feature)
        writer.close()
class BestSegmentPredictionGen:
    """Writes prediction records for every segment of selected documents and
    saves a data_id -> {qid, doc_id, seg_idx} map so model scores can later be
    joined back to their source segment."""

    def __init__(self, max_seq_length, split, skip_single_seg, pick_for_pairwise, out_dir):
        """Load split resources and prepare output/info directories.

        Args:
            max_seq_length: maximum token length used when encoding segments.
            split: dataset split passed to the query/segment loaders.
            skip_single_seg: if true, single-segment docs are not emitted.
            pick_for_pairwise: stored flag; NOTE(review): not read in work() —
                docs are always chosen via select_one_pos_neg_doc. Confirm intent.
            out_dir: output directory; "<out_dir>_info" is created next to it.
        """
        self.query_group: List[List[QueryID]] = load_query_group(split)
        self.seg_resource_loader = SegmentResourceLoader(job_man_dir, split)
        self.max_seq_length = max_seq_length
        self.out_dir = out_dir
        self.skip_single_seg = skip_single_seg
        self.pick_for_pairwise = pick_for_pairwise
        self.info_dir = self.out_dir + "_info"
        exist_or_mkdir(self.info_dir)

    def work(self, job_id):
        """Encode all segments of one pos and one neg doc per query, then dump
        the data-id info map for this job."""
        qids = self.query_group[job_id]
        # Disjoint data-id range per job keeps ids globally unique.
        max_data_per_job = 1000 * 1000
        base = job_id * max_data_per_job
        data_id_manager = DataIDManager(base, base + max_data_per_job)
        output_path = os.path.join(self.out_dir, str(job_id))
        writer = RecordWriterWrap(output_path)
        for qid in qids:
            # Fix: narrow the try so only the resource load can be excused as
            # "missing file"; errors in encoding/writing still surface.
            try:
                sr_per_qid = self.seg_resource_loader.load_for_qid(qid)
            except FileNotFoundError:
                if qid in missing_qids:
                    continue
                raise
            docs_to_predict = select_one_pos_neg_doc(sr_per_qid.sr_per_query_doc)
            for sr_per_doc in docs_to_predict:
                label_id = sr_per_doc.label
                if self.skip_single_seg and len(sr_per_doc.segs) == 1:
                    continue
                for seg_idx, seg in enumerate(sr_per_doc.segs):
                    info = {
                        'qid': qid,
                        'doc_id': sr_per_doc.doc_id,
                        'seg_idx': seg_idx
                    }
                    data_id = data_id_manager.assign(info)
                    feature = encode_sr(seg, self.max_seq_length, label_id, data_id)
                    writer.write_feature(feature)
        writer.close()
        info_save_path = os.path.join(self.info_dir, "{}.info".format(job_id))
        # Fix: close the info file deterministically instead of leaking the
        # handle returned by open() inside the json.dump call.
        with open(info_save_path, "w") as info_f:
            json.dump(data_id_manager.id_to_info, info_f)
class BestSegTrainGenPairwise:
    """Builds pairwise training records: for each query, the best segment of
    one positive document is paired with the best segment of one negative
    document, where "best" comes from a prior BestSegCollector run."""

    def __init__(self, max_seq_length, best_seg_collector: BestSegCollector, out_dir):
        """Load train-split resources and keep the best-segment collector."""
        self.query_group: List[List[QueryID]] = load_query_group("train")
        self.seg_resource_loader = SegmentResourceLoader(job_man_dir, "train")
        self.max_seq_length = max_seq_length
        self.out_dir = out_dir
        self.info_dir = self.out_dir + "_info"
        self.best_seg_collector = best_seg_collector
        exist_or_mkdir(self.info_dir)

    @staticmethod
    def pool_pos_neg_doc(doc_ids, sr_per_qid: SRPerQuery) \
            -> Tuple[SRPerQueryDoc, SRPerQueryDoc]:
        """Return one (positive, negative) doc pair among the given doc ids."""
        candidates = [doc for doc in sr_per_qid.sr_per_query_doc
                      if doc.doc_id in doc_ids]
        return select_one_pos_neg_doc(candidates, SRPerQueryDoc.get_label)

    def work(self, job_id):
        """Write one pairwise feature per query in this job's group."""
        best_idx_per_qid: Dict[str, Dict[str, int]] = \
            self.best_seg_collector.get_best_seg_info_2d(job_id)
        output_path = os.path.join(self.out_dir, str(job_id))
        writer = RecordWriterWrap(output_path)
        for qid in self.query_group[job_id]:
            sr_per_qid = self.seg_resource_loader.load_for_qid(qid)
            best_idx_per_doc = best_idx_per_qid[qid]
            doc_ids = list(best_idx_per_doc.keys())
            pos_doc, neg_doc = self.pool_pos_neg_doc(doc_ids, sr_per_qid)

            def best_segment(doc: SRPerQueryDoc) -> SegmentRepresentation:
                # Look up the previously-predicted best segment index.
                idx = best_idx_per_doc[doc.doc_id]
                try:
                    return doc.segs[idx]
                except IndexError:
                    # Report the offending doc before propagating.
                    print('qid={} doc_id={}'.format(qid, doc.doc_id))
                    print("max_seg_idx={} but len(segs)={}".format(idx, len(doc.segs)))
                    raise

            feature = encode_sr_pair(best_segment(pos_doc),
                                     best_segment(neg_doc),
                                     self.max_seq_length)
            writer.write_feature(feature)
        writer.close()
class BestSegTrainGen:
    """Generates pointwise training records using each multi-segment document's
    best segment, as determined by a prior BestSegCollector run."""

    def __init__(self, max_seq_length, best_seg_collector: BestSegCollector, out_dir):
        """Load train-split resources and keep the best-segment collector."""
        self.query_group: List[List[QueryID]] = load_query_group("train")
        self.seg_resource_loader = SegmentResourceLoader(job_man_dir, "train")
        self.max_seq_length = max_seq_length
        self.out_dir = out_dir
        self.info_dir = self.out_dir + "_info"
        self.best_seg_collector = best_seg_collector
        exist_or_mkdir(self.info_dir)

    def work(self, job_id):
        """Encode the best segment of every multi-segment doc for this job."""
        qid_to_max_seg_idx: Dict[Tuple[str, str], int] = \
            self.best_seg_collector.get_best_seg_info(job_id)
        qids = self.query_group[job_id]
        output_path = os.path.join(self.out_dir, str(job_id))
        writer = RecordWriterWrap(output_path)
        for qid in qids:
            sr_per_qid = self.seg_resource_loader.load_for_qid(qid)
            for sr_per_doc in sr_per_qid.sr_per_query_doc:
                # Single-segment docs carry no "best segment" choice; skip them.
                if len(sr_per_doc.segs) == 1:
                    continue
                qdid = qid, sr_per_doc.doc_id
                max_seg_idx = qid_to_max_seg_idx[qdid]
                label_id = sr_per_doc.label
                # Fix: try covers only the indexing the error message describes;
                # encode_sr/write_feature errors no longer trigger this handler.
                try:
                    seg = sr_per_doc.segs[max_seg_idx]
                except IndexError:
                    print('qid={} doc_id={}'.format(qid, sr_per_doc.doc_id))
                    print("max_seg_idx={} but len(segs)={}".format(
                        max_seg_idx, len(sr_per_doc.segs)))
                    raise
                feature = encode_sr(seg, self.max_seq_length, label_id)
                writer.write_feature(feature)
        writer.close()
def main3():
    """Ad-hoc check: load the segment resources for one fixed train query."""
    loader = SegmentResourceLoader(job_man_dir, "train")
    sr_per_query = loader.load_for_qid("1000633")