Example #1
 def __init__(self, max_seq_length,
              out_dir):
     """Set up the generator for the "train" split.

     Loads the per-job query grouping and a segment resource loader,
     and ensures the side-car info directory exists.

     :param max_seq_length: maximum sequence length used when encoding.
     :param out_dir: directory for the per-job output files.
     """
     self.query_group: List[List[QueryID]] = load_query_group("train")
     self.seg_resource_loader = SegmentResourceLoader(job_man_dir, "train")
     self.max_seq_length = max_seq_length
     self.out_dir = out_dir
     # Info files live beside the output dir, named "<out_dir>_info".
     self.info_dir = self.out_dir + "_info"
     exist_or_mkdir(self.info_dir)
Example #2
 def __init__(self,
              max_seq_length,
              split,
              skip_single_seg,
              pick_for_pairwise,
              out_dir):
     """Set up the generator for the given dataset split.

     :param max_seq_length: maximum sequence length used when encoding.
     :param split: dataset split name (e.g. "train") used to load
         query groups and segment resources.
     :param skip_single_seg: when true, docs with a single segment are
         skipped by the worker.
     :param pick_for_pairwise: option consumed elsewhere — presumably
         controls pairwise document selection; confirm against callers.
     :param out_dir: directory for the per-job output files.
     """
     self.query_group: List[List[QueryID]] = load_query_group(split)
     self.seg_resource_loader = SegmentResourceLoader(job_man_dir, split)
     self.max_seq_length = max_seq_length
     self.out_dir = out_dir
     self.skip_single_seg = skip_single_seg
     self.pick_for_pairwise = pick_for_pairwise
     # Info files live beside the output dir, named "<out_dir>_info".
     self.info_dir = self.out_dir + "_info"
     exist_or_mkdir(self.info_dir)
Example #3
def seg_resource_loader_test():
    """Smoke-test SegmentResourceLoader: load one qid and print a sample.

    Prints the first doc's id and segment count, plus the first segment's
    two halves, for a fixed training query.
    """
    loader = SegmentResourceLoader(job_man_dir, "train")
    qid = "1000008"
    sr_per_query: SRPerQuery = loader.load_for_qid(qid)

    assert qid == sr_per_query.qid
    docs = sr_per_query.sr_per_query_doc
    if docs:
        first_doc = docs[0]
        print(first_doc.doc_id)
        print(len(first_doc.segs))
        if first_doc.segs:
            first_seg = first_doc.segs[0]
            print(first_seg.first_seg)
            print(first_seg.second_seg)
Example #4
class SingleSegTrainGen:
    """Writes training records for documents that have exactly one segment.

    For each query in a job's query group, loads its segment resources,
    keeps only single-segment documents, and encodes that lone segment
    into a feature written to a per-job output file.
    """
    def __init__(self, max_seq_length,
                 out_dir):
        """:param max_seq_length: maximum sequence length used when encoding.
        :param out_dir: directory for the per-job output files.
        """
        self.query_group: List[List[QueryID]] = load_query_group("train")
        self.seg_resource_loader = SegmentResourceLoader(job_man_dir, "train")
        self.max_seq_length = max_seq_length
        self.out_dir = out_dir
        # Info files live beside the output dir, named "<out_dir>_info".
        self.info_dir = self.out_dir + "_info"
        exist_or_mkdir(self.info_dir)

    def work(self, job_id):
        """Encode every single-segment doc for this job's queries.

        :param job_id: index into the query group; also names the output file.
        :raises FileNotFoundError: when a qid's resource file is missing and
            the qid is not in the known ``missing_qids`` set.
        """
        qids = self.query_group[job_id]
        output_path = os.path.join(self.out_dir, str(job_id))
        writer = RecordWriterWrap(output_path)
        for qid in qids:
            try:
                sr_per_qid = self.seg_resource_loader.load_for_qid(qid)
                for sr_per_doc in sr_per_qid.sr_per_query_doc:
                    # Keep only docs with exactly one segment. The previous
                    # `> 1` check let an empty segs list through, which
                    # crashed on segs[0] below.
                    if len(sr_per_doc.segs) != 1:
                        continue
                    label_id = sr_per_doc.label
                    seg = sr_per_doc.segs[0]
                    feature = encode_sr(seg,
                                        self.max_seq_length,
                                        label_id,
                                        )
                    writer.write_feature(feature)
            except FileNotFoundError:
                # Some qids are known to have no resource file; tolerate
                # only those, re-raise anything unexpected.
                if qid not in missing_qids:
                    raise
        writer.close()
Example #5
class BestSegmentPredictionGen:
    """Writes prediction records for every segment of selected documents.

    For each query, one positive and one negative doc are selected, and every
    segment of each selected doc is encoded with a data id so that per-segment
    predictions can later be mapped back to (qid, doc_id, seg_idx).
    """
    def __init__(self,
                 max_seq_length,
                 split,
                 skip_single_seg,
                 pick_for_pairwise,
                 out_dir):
        """:param max_seq_length: maximum sequence length used when encoding.
        :param split: dataset split name used to load queries and resources.
        :param skip_single_seg: when true, docs with one segment are skipped.
        :param pick_for_pairwise: stored option — presumably controls pairwise
            doc selection elsewhere; confirm against callers.
        :param out_dir: directory for the per-job output files.
        """
        self.query_group: List[List[QueryID]] = load_query_group(split)
        self.seg_resource_loader = SegmentResourceLoader(job_man_dir, split)
        self.max_seq_length = max_seq_length
        self.out_dir = out_dir
        self.skip_single_seg = skip_single_seg
        self.pick_for_pairwise = pick_for_pairwise
        # Info files live beside the output dir, named "<out_dir>_info".
        self.info_dir = self.out_dir + "_info"
        exist_or_mkdir(self.info_dir)

    def work(self, job_id):
        """Encode all segments for this job's queries and save the id->info map.

        :param job_id: index into the query group; also names the output and
            info files, and offsets the data-id range.
        :raises FileNotFoundError: when a qid's resource file is missing and
            the qid is not in the known ``missing_qids`` set.
        """
        qids = self.query_group[job_id]
        # Each job gets a disjoint range of one million data ids.
        max_data_per_job = 1000 * 1000
        base = job_id * max_data_per_job
        data_id_manager = DataIDManager(base, base+max_data_per_job)
        output_path = os.path.join(self.out_dir, str(job_id))
        writer = RecordWriterWrap(output_path)

        for qid in qids:
            try:
                sr_per_qid = self.seg_resource_loader.load_for_qid(qid)
                docs_to_predict = select_one_pos_neg_doc(sr_per_qid.sr_per_query_doc)
                for sr_per_doc in docs_to_predict:
                    label_id = sr_per_doc.label
                    if self.skip_single_seg and len(sr_per_doc.segs) == 1:
                        continue
                    for seg_idx, seg in enumerate(sr_per_doc.segs):
                        # Record enough to map a prediction back to its segment.
                        info = {
                            'qid': qid,
                            'doc_id': sr_per_doc.doc_id,
                            'seg_idx': seg_idx
                        }
                        data_id = data_id_manager.assign(info)
                        feature = encode_sr(seg,
                                            self.max_seq_length,
                                            label_id,
                                            data_id)
                        writer.write_feature(feature)
            except FileNotFoundError:
                # Some qids are known to have no resource file; tolerate
                # only those, re-raise anything unexpected.
                if qid in missing_qids:
                    pass
                else:
                    raise
        writer.close()
        info_save_path = os.path.join(self.info_dir, "{}.info".format(job_id))
        # `with` guarantees the info file is flushed and closed; the previous
        # bare open() leaked the handle.
        with open(info_save_path, "w") as f:
            json.dump(data_id_manager.id_to_info, f)
Example #6
class BestSegTrainGenPairwise:
    """Writes pairwise training records built from each doc's best segment.

    For every query, one positive and one negative document are chosen from
    the docs that have best-segment info, and the pre-computed best segment
    of each is encoded together as a single pairwise feature.
    """
    def __init__(self, max_seq_length,
                 best_seg_collector: BestSegCollector,
                 out_dir):
        """:param max_seq_length: maximum sequence length used when encoding.
        :param best_seg_collector: supplies the best segment index per doc.
        :param out_dir: directory for the per-job output files.
        """
        self.query_group: List[List[QueryID]] = load_query_group("train")
        self.seg_resource_loader = SegmentResourceLoader(job_man_dir, "train")
        self.max_seq_length = max_seq_length
        self.out_dir = out_dir
        self.info_dir = self.out_dir + "_info"
        self.best_seg_collector = best_seg_collector
        exist_or_mkdir(self.info_dir)

    @staticmethod
    def pool_pos_neg_doc(doc_ids, sr_per_qid: SRPerQuery) \
            -> Tuple[SRPerQueryDoc, SRPerQueryDoc]:
        """Pick one positive and one negative doc among the given doc ids."""
        # Restrict the candidate pool to docs we have best-segment info for.
        candidates = [d for d in sr_per_qid.sr_per_query_doc if d.doc_id in doc_ids]
        return select_one_pos_neg_doc(candidates, SRPerQueryDoc.get_label)

    def work(self, job_id):
        """Write one pairwise feature (pos vs neg best segment) per query."""
        seg_idx_by_qid: Dict[str, Dict[str, int]] = self.best_seg_collector.get_best_seg_info_2d(job_id)
        writer = RecordWriterWrap(os.path.join(self.out_dir, str(job_id)))
        for qid in self.query_group[job_id]:
            sr_per_qid = self.seg_resource_loader.load_for_qid(qid)
            seg_idx_of_doc = seg_idx_by_qid[qid]
            doc_ids = list(seg_idx_of_doc.keys())
            pos_doc, neg_doc = self.pool_pos_neg_doc(doc_ids, sr_per_qid)

            def best_seg(doc: SRPerQueryDoc) -> SegmentRepresentation:
                # Fetch the pre-computed best segment; diagnose and re-raise
                # if the recorded index is out of range for this doc.
                idx = seg_idx_of_doc[doc.doc_id]
                try:
                    return doc.segs[idx]
                except IndexError:
                    print('qid={} doc_id={}'.format(qid, doc.doc_id))
                    print("max_seg_idx={} but len(segs)={}".format(idx, len(doc.segs)))
                    raise

            feature = encode_sr_pair(best_seg(pos_doc),
                                     best_seg(neg_doc),
                                     self.max_seq_length,
                                     )
            writer.write_feature(feature)
        writer.close()
Example #7
class BestSegTrainGen:
    """Writes pointwise training records built from each doc's best segment.

    For every multi-segment doc of each query, encodes only the pre-computed
    best segment together with the doc's label.
    """
    def __init__(self, max_seq_length,
                 best_seg_collector: BestSegCollector,
                 out_dir):
        """:param max_seq_length: maximum sequence length used when encoding.
        :param best_seg_collector: supplies the best segment index per doc.
        :param out_dir: directory for the per-job output files.
        """
        self.query_group: List[List[QueryID]] = load_query_group("train")
        self.seg_resource_loader = SegmentResourceLoader(job_man_dir, "train")
        self.max_seq_length = max_seq_length
        self.out_dir = out_dir
        self.info_dir = self.out_dir + "_info"
        self.best_seg_collector = best_seg_collector
        exist_or_mkdir(self.info_dir)

    def work(self, job_id):
        """Encode the best segment of every multi-segment doc for this job."""
        best_idx: Dict[Tuple[str, str], int] = self.best_seg_collector.get_best_seg_info(job_id)
        writer = RecordWriterWrap(os.path.join(self.out_dir, str(job_id)))
        for qid in self.query_group[job_id]:
            sr_per_qid = self.seg_resource_loader.load_for_qid(qid)
            for sr_per_doc in sr_per_qid.sr_per_query_doc:
                # A single-segment doc offers no choice of segment; skip it.
                if len(sr_per_doc.segs) == 1:
                    continue
                max_seg_idx = best_idx[(qid, sr_per_doc.doc_id)]
                label_id = sr_per_doc.label
                try:
                    best_seg = sr_per_doc.segs[max_seg_idx]
                    feature = encode_sr(best_seg,
                                        self.max_seq_length,
                                        label_id,
                                        )
                    writer.write_feature(feature)
                except IndexError:
                    # Diagnose an out-of-range recorded index, then re-raise.
                    print('qid={} doc_id={}'.format(qid, sr_per_doc.doc_id))
                    print("max_seg_idx={} but len(segs)={}".format(max_seg_idx, len(sr_per_doc.segs)))
                    raise

        writer.close()
Example #8
def main3():
    srl = SegmentResourceLoader(job_man_dir, "train")
    sr_per_query = srl.load_for_qid("1000633")