Example #1
from typing import Dict, Iterator, List, Tuple

from nltk.tokenize import sent_tokenize


def sentence_payload_gen(q_res_path: str, top_n: int,
                         data_id_man: DataIDManager) -> List[Tuple[str, bool, int]]:
    print("loading ranked list")
    ranked_list: Dict[
        str, List[SimpleRankedListEntry]] = load_galago_ranked_list(q_res_path)
    # NOTE: only the first 10 queries are kept; this looks like a debugging
    # cap left in place.
    qid_list = list(ranked_list.keys())[:10]
    ranked_list = {k: ranked_list[k] for k in qid_list}
    print("Preloading docs")
    preload_docs(ranked_list, top_n)
    entries: List[Tuple[str, bool, int]] = []

    def enum_sentence(tokens) -> Iterator[str]:
        # Re-join the token list and split it back into sentences.
        text = " ".join(tokens)
        yield from sent_tokenize(text)

    ticker = TimeEstimator(len(ranked_list))
    for qid in ranked_list:
        q_res: List[SimpleRankedListEntry] = ranked_list[qid]
        docs = iterate_docs(q_res, top_n)

        for doc in docs:
            for sent_idx, sent in enumerate(enum_sentence(doc.tokens)):
                # Record where each sentence came from, keyed by its data id.
                info = {
                    'doc_id': doc.doc_id,
                    'sent_idx': sent_idx,
                    'sentence': sent
                }
                data_id = data_id_man.assign(info)
                entries.append((sent, True, data_id))

        ticker.tick()
    return entries
Example #2
from typing import Iterable, Tuple


def collect_info_transform(data: Iterable[Tuple[QCKQuery, QCKCandidate, bool]],
                           data_id_man: DataIDManager) -> Iterable[QCInstance]:
    for query, candidate, is_correct in data:
        # Keep lightweight copies of the query/candidate for later lookup.
        info = {
            'query': get_light_qckquery(query),
            'candidate': get_light_qckcandidate(candidate)
        }
        yield QCInstance(query.text, candidate.text, data_id_man.assign(info),
                         int(is_correct))
Example #3
    def generate(self, data_id_manager: DataIDManager,
                 query_list) -> List[ClassificationInstanceWDataID]:
        neg_k = 1000  # consider at most the top-1000 ranked docs per query
        all_insts = []
        for query_id in query_list:
            if query_id not in self.judgement:
                continue

            judgement = self.judgement[query_id]
            query = self.queries[query_id]
            query_tokens = self.tokenizer.tokenize(query)

            ranked_list = self.galago_rank[query_id][:neg_k]

            # Candidate docs: everything judged plus everything retrieved.
            target_docs = set(judgement.keys())
            target_docs.update(e.doc_id for e in ranked_list)
            print("Total of {} docs".format(len(target_docs)))

            for doc_id in target_docs:
                tokens = self.data[doc_id]
                # Each doc is encoded as a list of (tokens, segment_ids) passages.
                insts: List[Tuple[List, List]] = self.encoder.encode(
                    query_tokens, tokens)
                label = 1 if judgement.get(doc_id, 0) > 0 else 0
                if label:
                    # Positive doc: select passages by their relevance scores.
                    passage_scores = [
                        self.scores[query_id, doc_id, idx]
                        for idx, _ in enumerate(insts)
                    ]
                    target_indices = self.get_target_indices(passage_scores)
                else:
                    # Negative doc: take the first passage, and 10% of the
                    # time also one random later passage.
                    target_indices = [0]
                    n = len(insts)
                    if random.random() < 0.1 and n > 1:
                        target_indices.append(random.randint(1, n - 1))

                for passage_idx in target_indices:
                    tokens_seg, seg_ids = insts[passage_idx]
                    assert isinstance(tokens_seg[0], str)
                    assert isinstance(seg_ids[0], int)
                    data_id = data_id_manager.assign({
                        'doc_id': doc_id,
                        'passage_idx': passage_idx,
                        'label': label,
                        'tokens': tokens_seg,
                        'seg_ids': seg_ids,
                    })
                    all_insts.append(
                        ClassificationInstanceWDataID(tokens_seg, seg_ids,
                                                      label, data_id))

        return all_insts
Example #4
import json


def main():
    data_id_manager = DataIDManager()
    data = []
    for text in enum_f5_data():
        info = {
            'text': text,
        }
        data_id = data_id_manager.assign(info)
        label = 0  # unlabeled data: every instance gets the dummy label 0
        data.append(TextInstance(text, label, data_id))

    # Encode instances to a tfrecord (max sequence length 512).
    encode_fn = get_encode_fn_w_data_id(512, False)
    save_path = at_output_dir("clue_counter_arg", "clue_f5.tfrecord")
    write_records_w_encode_fn(save_path, encode_fn, data)

    # Save the id -> info mapping alongside the tfrecord.
    info_save_path = at_output_dir("clue_counter_arg", "clue_f5.tfrecord.info")
    with open(info_save_path, "w") as f:
        json.dump(data_id_manager.id_to_info, f)
Example #5
    def work(self, job_id):
        qids = self.query_group[job_id]
        # Give each job a disjoint id range so parallel jobs cannot collide:
        # job 0 gets [0, 1M), job 1 gets [1M, 2M), and so on.
        max_data_per_job = 1000 * 1000
        base = job_id * max_data_per_job
        data_id_manager = DataIDManager(base, base + max_data_per_job)
        output_path = os.path.join(self.out_dir, str(job_id))
        writer = RecordWriterWrap(output_path)

        for qid in qids:
            try:
                sr_per_qid = self.seg_resource_loader.load_for_qid(qid)
                docs_to_predict = select_one_pos_neg_doc(sr_per_qid.sr_per_query_doc)
                for sr_per_doc in docs_to_predict:
                    label_id = sr_per_doc.label
                    if self.skip_single_seg and len(sr_per_doc.segs) == 1:
                        continue
                    for seg_idx, seg in enumerate(sr_per_doc.segs):
                        info = {
                            'qid': qid,
                            'doc_id': sr_per_doc.doc_id,
                            'seg_idx': seg_idx
                        }
                        data_id = data_id_manager.assign(info)
                        feature = encode_sr(seg,
                                            self.max_seq_length,
                                            label_id,
                                            data_id)
                        writer.write_feature(feature)
            except FileNotFoundError:
                # Tolerate qids known to be missing; re-raise otherwise.
                if qid not in missing_qids:
                    raise

        writer.close()
        info_save_path = os.path.join(self.info_dir, "{}.info".format(job_id))
        with open(info_save_path, "w") as f:
            json.dump(data_id_manager.id_to_info, f)
Example #6
    def generate(self, triplet_itr: Iterator[Tuple[Query, Doc, Doc]]):
        data_id_manager = DataIDManager()
        for q, d1, d2 in triplet_itr:
            # No metadata to track here, so assign ids against an empty info dict.
            data_id = data_id_manager.assign({})
            inst = self.triplet_to_paired_instance(q, d1, d2, data_id)
            yield inst