Example #1
def write_qck_as_tfrecord(save_path, payloads: Iterable[QCKCompactEntry]):
    data_id_man = DataIDManager(0, 1000 * 1000)

    tokenizer = get_tokenizer()
    cache_tokenizer = CachedTokenizer(tokenizer)
    max_seq_length = 512

    def encode_fn(e: QCKCompactEntry) -> OrderedDict:
        query, candidate, qk_out_entry = e
        candidate: QCKCandidate = candidate
        info = {
            'query': query,
            'candidate': candidate,
            'kdp': qk_out_entry.kdp
        }

        p = PayloadAsTokens(passage=qk_out_entry.passage_tokens,
                            text1=cache_tokenizer.tokenize(query.text),
                            text2=cache_tokenizer.tokenize(candidate.text),
                            data_id=data_id_man.assign(info),
                            is_correct=0
                            )
        return encode_two_inputs(max_seq_length, tokenizer, p)

    write_records_w_encode_fn(save_path, encode_fn, payloads)
    return data_id_man
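
Every example on this page centers on DataIDManager, whose definition is not shown. Below is a minimal sketch inferred from the call sites: ids are assigned sequentially from a base (with an optional upper bound for sharded jobs), and each assigned id maps to its info dict via id_to_info, which callers dump to a ".info" file. This illustrates the apparent contract, not the actual clover3/Chair implementation; note that some call sites appear to pass a count rather than an end id as the second argument.

class DataIDManager:
    # Inferred sketch, not the real class.
    def __init__(self, base=0, end=None):
        self.cursor = base
        self.end = end
        self.id_to_info = {}

    def assign(self, info):
        data_id = self.cursor
        if self.end is not None:
            assert data_id < self.end, "id range for this job is exhausted"
        self.cursor += 1
        self.id_to_info[data_id] = info
        return data_id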
Example #2
def make_cppnc_problem(passage_score_path: FilePath, data_id_to_info: Dict,
                       claims: List[Dict], candidate_perspectives, config,
                       save_name: str, encode_inner_fn) -> None:
    output: List[Tuple[int, List[Dict]]] = collect_good_passages(
        data_id_to_info, passage_score_path, config)
    joined_payloads: List = list(
        join_perspective(output, candidate_perspectives))
    tokenizer = get_tokenizer()
    data_id_man = DataIDManager()

    payloads: Iterable[PayloadAsTokens] = put_texts(joined_payloads, claims,
                                                    tokenizer, data_id_man)
    max_seq_length = 512

    def encode_fn(r: PayloadAsTokens):
        return encode_inner_fn(max_seq_length, tokenizer, r)

    out_dir = os.path.join(output_path, "cppnc")
    exist_or_mkdir(out_dir)
    save_path = os.path.join(out_dir, save_name + ".tfrecord")
    write_records_w_encode_fn(save_path, encode_fn, payloads)
    info_save_path = os.path.join(out_dir, save_name + ".info")
    print("Payload size : ", len(data_id_man.id_to_info))

    json.dump(data_id_man.id_to_info, open(info_save_path, "w"))
    print("tfrecord saved at :", save_path)
    print("info saved at :", info_save_path)
Example #3
def make_cppnc_dummy_problem(claims: List[Dict], candidate_perspectives,
                             save_name: str, encode_inner_fn) -> None:

    empty_passage = {'passage': []}

    def get_payload() -> Iterable[Tuple[int, int, List[Dict]]]:
        for cid, candidates in candidate_perspectives.items():
            for candi in candidates:
                yield cid, candi['pid'], [empty_passage]

    tokenizer = get_tokenizer()
    data_id_man = DataIDManager()

    payloads: Iterable[PayloadAsTokens] = put_texts(get_payload(), claims,
                                                    tokenizer, data_id_man)
    max_seq_length = 512

    def encode_fn(r: PayloadAsTokens):
        return encode_inner_fn(max_seq_length, tokenizer, r)

    out_dir = os.path.join(output_path, "cppnc")
    exist_or_mkdir(out_dir)
    save_path = os.path.join(out_dir, save_name + ".tfrecord")
    write_records_w_encode_fn(save_path, encode_fn, payloads)
    info_save_path = os.path.join(out_dir, save_name + ".info")
    print("Payload size : ", len(data_id_man.id_to_info))

    json.dump(data_id_man.id_to_info, open(info_save_path, "w"))
    print("tfrecord saved at :", save_path)
    print("info saved at :", info_save_path)
Example #4
    def generate_selected_training_data_ablation_only_pos(info, key, max_seq_length, save_dir, score_dir):
        data_id_manager = DataIDManager(0, 1000000)
        out_path = os.path.join(save_dir, str(key))
        pred_path = os.path.join(score_dir, str(key))
        tprint("data gen")
        itr = enum_best_segments(pred_path, info)
        insts = []
        for selected_entry in itr:
            selected = decompress_seg_ids_entry(selected_entry)
            assert len(selected['input_ids']) == len(selected['seg_ids'])

            selected['input_ids'] = pad0(selected['input_ids'], max_seq_length)
            selected['seg_ids'] = pad0(selected['seg_ids'], max_seq_length)
            # data_id = data_id_manager.assign(selected_segment.to_info_d())
            data_id = 0  # info tracking disabled, so the .info dump below will be empty
            ci = InstAsInputIds(
                selected['input_ids'],
                selected['seg_ids'],
                selected['label'],
                data_id)
            insts.append(ci)

        def encode_fn(inst: InstAsInputIds) -> collections.OrderedDict:
            return encode_inst_as_input_ids(max_seq_length, inst)

        tprint("writing")
        write_records_w_encode_fn(out_path, encode_fn, insts, len(insts))
        save_info(save_dir, data_id_manager, str(key) + ".info")
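
Examples #4 and #26 delegate the ".info" dump to a save_info helper instead of calling json.dump directly. A plausible minimal version, assuming it mirrors the explicit dumps in the other examples (name and signature taken from the call sites):

import json
import os

def save_info(save_dir, data_id_manager, info_file_name):
    # Hypothetical sketch: persist the id -> info mapping next to the records.
    info_path = os.path.join(save_dir, info_file_name)
    with open(info_path, "w") as f:
        json.dump(data_id_manager.id_to_info, f)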
Example #5
def sentence_payload_gen(q_res_path: str, top_n, data_id_man: DataIDManager):
    print("loading ranked list")
    ranked_list: Dict[
        str, List[SimpleRankedListEntry]] = load_galago_ranked_list(q_res_path)
    qid_list = list(ranked_list.keys())
    qid_list = qid_list[:10]  # note: only the first 10 queries are kept
    ranked_list = {k: ranked_list[k] for k in qid_list}
    print("Pre loading docs")
    preload_docs(ranked_list, top_n)
    entries: List[Tuple[str, bool, int]] = []

    def enum_sentence(tokens) -> Iterator[str]:
        text = " ".join(tokens)
        sents = sent_tokenize(text)
        yield from sents

    ticker = TimeEstimator(len(ranked_list))
    for qid in ranked_list:
        q_res: List[SimpleRankedListEntry] = ranked_list[qid]
        docs = iterate_docs(q_res, top_n)

        for doc in docs:
            for sent_idx, sent in enumerate(enum_sentence(doc.tokens)):
                info = {
                    'doc_id': doc.doc_id,
                    'sent_idx': sent_idx,
                    'sentence': sent
                }
                data_id = data_id_man.assign(info)
                e = sent, True, data_id
                entries.append(e)

        ticker.tick()
    return entries
Example #6
def generate(pos_doc_ids, all_doc_list, max_seq_length) -> List[Instance]:
    # load list of documents
    # make list of negative documents.
    # remove duplicates.
    seq_length = max_seq_length - 2
    neg_docs_ids = list([d for d in all_doc_list if d not in pos_doc_ids])
    pos_docs: List[List[List[str]]] = load_multiple(BertTokenizedCluewebDoc,
                                                    pos_doc_ids, True)
    hashes = lmap(doc_hash, pos_docs)
    duplicate_indices = get_duplicate_list(hashes)
    pos_docs: List[List[List[str]]] = list(
        [doc for i, doc in enumerate(pos_docs) if i not in duplicate_indices])
    neg_docs: List[List[List[str]]] = load_multiple_divided(
        BertTokenizedCluewebDoc, neg_docs_ids, True)

    data_id_man = DataIDManager()

    def enum_instances(doc_list: List[List[List[str]]],
                       label: int) -> Iterator[Instance]:
        for d in doc_list:
            for passage in enum_passages(d, seq_length):
                yield Instance(passage, data_id_man.assign([]), label)  # empty info payload

    pos_insts = list(enum_instances(pos_docs, 1))
    neg_insts = list(enum_instances(neg_docs, 0))
    all_insts = pos_insts + neg_insts
    print("{} instances".format(len(all_insts)))
    random.shuffle(all_insts)
    return all_insts
Example #7
File: qc_gen.py  Project: clover3/Chair
def do_generate_jobs(candidate_dict, is_correct_fn, save_dir, split):
    queries = get_qck_queries(split)
    generator = QCInstanceGenerator(candidate_dict, is_correct_fn)
    data_id_manager = DataIDManager()
    insts = generator.generate(queries, data_id_manager)
    save_path = os.path.join(save_dir, split)
    write_records_w_encode_fn(save_path, generator.encode_fn, insts)
    json.dump(data_id_manager.id_to_info, open(save_path + ".info", "w"))
Example #8
def main():
    data_id_manager = DataIDManager()
    data = []
    for text in enum_f5_data():
        info = {
            'text': text,
        }
        data_id = data_id_manager.assign(info)
        label = 0
        data.append(TextInstance(text, label, data_id))

    encode_fn = get_encode_fn_w_data_id(512, False)
    save_path = at_output_dir("clue_counter_arg", "clue_f5.tfrecord")
    write_records_w_encode_fn(save_path, encode_fn, data)

    info_save_path = at_output_dir("clue_counter_arg", "clue_f5.tfrecord.info")
    json.dump(data_id_manager.id_to_info, open(info_save_path, "w"))
Example #9
def main():
    raw_payload: List[ClaimPassages] = load_dev_payload()
    save_path = os.path.join(output_path, "pc_dev_passage_payload")
    encode = get_encode_fn(512)
    data_id_manage = DataIDManager()
    insts = list(generate_instances(raw_payload, data_id_manage))
    write_records_w_encode_fn(save_path, encode, insts, len(insts))
    save_to_pickle(data_id_manage.id_to_info, "pc_dev_passage_payload_info")
Example #10
def make_pc_qc(queries: Iterable[QCKQuery],
               eval_candidate: Dict[str, List[QCKCandidate]], is_correct_fn,
               save_path: str):
    generator = QCInstanceGenerator(eval_candidate, is_correct_fn)
    data_id_manager = DataIDManager(0, 10000 * 10000)
    insts = generator.generate(queries, data_id_manager)
    insts = list(insts)
    write_records_w_encode_fn(save_path, generator.encode_fn, insts)
    json.dump(data_id_manager.id_to_info, open(save_path + ".info", "w"))
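
The assigned ids pay off at read time: scores from a model run can be joined back to their source info through the saved ".info" file. A hypothetical sketch (the `scores` iterable of (data_id, score) pairs is assumed, not from this repo):

import json

info = json.load(open(save_path + ".info"))
for data_id, score in scores:
    meta = info[str(data_id)]  # json.dump stringifies the integer keys
    print(meta, score)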
Example #11
def collect_info_transform(data: Iterable[Tuple[QCKQuery, QCKCandidate, bool]], data_id_man: DataIDManager) \
        -> Iterable[QCInstance]:
    for query, candidate, is_correct in data:
        info = {
            'query': get_light_qckquery(query),
            'candidate': get_light_qckcandidate(candidate)
        }
        yield QCInstance(query.text, candidate.text, data_id_man.assign(info),
                         int(is_correct))
Example #12
def generate_and_write(file_name, generate_fn, tokenizer):
    data_id_man = DataIDManager()
    inst_list = generate_fn(data_id_man)
    max_seq_length = 300
    save_path = at_output_dir("alamri_tfrecord", file_name)
    encode_fn = get_encode_fn(max_seq_length, tokenizer)
    write_records_w_encode_fn(save_path, encode_fn, inst_list)
    info_save_path = at_output_dir("alamri_tfrecord", file_name + ".info")
    json.dump(data_id_man.id_to_info, open(info_save_path, "w"))
Example #13
def main():
    exist_or_mkdir(os.path.join(output_path, "alamri_tfrecord"))

    data_id_manager = DataIDManager()  # note: created but unused in this snippet
    entries = []
    for claim1, claim2 in enum_true_instance():
        entries.append((claim1.text, claim2.text))

    save_path = at_output_dir("alamri_pilot", "true_pairs_all.csv")
    csv_writer = csv.writer(open(save_path, "w", newline='', encoding="utf-8"))
    foreach(csv_writer.writerow, entries)
Example #14
    def generate(self, data_id_manager: DataIDManager,
                 query_list) -> List[ClassificationInstanceWDataID]:
        neg_k = 1000
        all_insts = []
        for query_id in query_list:
            if query_id not in self.judgement:
                continue

            judgement = self.judgement[query_id]
            query = self.queries[query_id]
            query_tokens = self.tokenizer.tokenize(query)

            ranked_list = self.galago_rank[query_id]
            ranked_list = ranked_list[:neg_k]

            target_docs = set(judgement.keys())
            target_docs.update([e.doc_id for e in ranked_list])
            print("Total of {} docs".format(len(target_docs)))

            for doc_id in target_docs:
                tokens = self.data[doc_id]
                insts: List[Tuple[List, List]] = self.encoder.encode(
                    query_tokens, tokens)
                label = 1 if doc_id in judgement and judgement[doc_id] > 0 else 0
                if label:
                    passage_scores = list([
                        self.scores[query_id, doc_id, idx]
                        for idx, _ in enumerate(insts)
                    ])
                    target_indices = self.get_target_indices(passage_scores)
                else:
                    target_indices = [0]
                    n = len(insts)
                    if random.random() < 0.1 and n > 1:
                        idx = random.randint(1, n - 1)
                        target_indices.append(idx)

                for passage_idx in target_indices:
                    tokens_seg, seg_ids = insts[passage_idx]
                    assert type(tokens_seg[0]) == str
                    assert type(seg_ids[0]) == int
                    data_id = data_id_manager.assign({
                        'doc_id': doc_id,
                        'passage_idx': passage_idx,
                        'label': label,
                        'tokens': tokens_seg,
                        'seg_ids': seg_ids,
                    })
                    all_insts.append(
                        ClassificationInstanceWDataID(tokens_seg, seg_ids,
                                                      label, data_id))

        return all_insts
Example #15
def write_qc_records(output_path, qc_records):
    data_id_man = DataIDManager()
    instances = collect_info_transform(qc_records, data_id_man)
    tokenizer = get_tokenizer()
    max_seq_length = 512

    def encode_fn(inst: QCInstance):
        return encode(tokenizer, max_seq_length, inst)

    write_records_w_encode_fn(output_path, encode_fn, instances)
    json.dump(data_id_man.id_to_info, open(output_path + ".info", "w"))
Example #16
    def work(self, job_id):
        base = job_id * 10000
        data_id_manager = DataIDManager(base)
        insts: List = self.generator.generate_instances(
            self.claims[job_id], data_id_manager)
        print("{} instances".format(len(insts)))
        self.writer(insts, self.max_seq_length,
                    os.path.join(self.out_dir, str(job_id)))

        info_dir = self.out_dir + "_info"
        exist_or_mkdir(info_dir)
        info_path = os.path.join(info_dir, str(job_id) + ".info")
        json.dump(data_id_manager.id_to_info, open(info_path, "w"))
Example #17
    def work(self, job_id):
        qid = self.qid_list[job_id]
        out_path = os.path.join(self.out_path, str(job_id))
        max_inst_per_job = 1000 * 1000
        base = job_id * max_inst_per_job
        data_id_manager = DataIDManager(base, max_inst_per_job)
        insts = self.gen.generate([str(qid)], data_id_manager)
        self.gen.write(insts, out_path)

        info_dir = self.out_path + "_info"
        exist_or_mkdir(info_dir)
        info_path = os.path.join(info_dir, str(job_id) + ".info")
        json.dump(data_id_manager.id_to_info, open(info_path, "w"))
Example #18
    def work(self, job_id):
        qids = self.query_group[job_id]
        data_bin = 100000
        data_id_st = job_id * data_bin
        data_id_ed = data_id_st + data_bin
        data_id_manager = DataIDManager(data_id_st, data_id_ed)
        tprint("generating instances")
        insts = self.generator.generate(data_id_manager, qids)
        # tprint("{} instances".format(len(insts)))
        out_path = os.path.join(self.out_dir, str(job_id))
        self.generator.write(insts, out_path)

        info_path = os.path.join(self.info_dir, "{}.info".format(job_id))
        json.dump(data_id_manager.id_to_info, open(info_path, "w"))
Example #19
def main():
    data_id_man = DataIDManager()
    q_res_path = sys.argv[1]
    save_path = sys.argv[2]
    max_seq_length = 512
    tokenizer = get_tokenizer()
    insts = sentence_payload_gen(q_res_path, 100, data_id_man)

    def encode_fn(t: Tuple[str, bool, int]) -> OrderedDict:
        return encode_w_data_id(tokenizer, max_seq_length, t)

    write_records_w_encode_fn(save_path, encode_fn, insts)
    json_save_path = save_path + ".info"
    json.dump(data_id_man.id_to_info, open(json_save_path, "w"))
Example #20
    def work(self, job_id):
        st, ed = robust_query_intervals[job_id]
        out_path = os.path.join(self.out_path, str(st))

        query_list = [str(i) for i in range(st, ed+1)]
        unit = 1000 * 1000
        base = unit * job_id
        end = base + unit
        data_id_manager = DataIDManager(base, end)
        insts = self.gen.generate(data_id_manager, query_list)
        self.gen.write(insts, out_path)

        info_path = os.path.join(self.info_dir, str(st))
        json.dump(data_id_manager.id_to_info, open(info_path, "w"))
Example #21
    def work(self, job_id):
        st, ed = robust_query_intervals[job_id]
        out_path = os.path.join(self.out_path, str(st))
        max_inst_per_job = 1000 * 10000
        base = job_id * max_inst_per_job
        data_id_manager = DataIDManager(base, max_inst_per_job)
        query_list = [str(i) for i in range(st, ed + 1)]
        insts = self.gen.generate(query_list, data_id_manager)
        self.gen.write(insts, out_path)

        info_dir = self.out_path + "_info"
        exist_or_mkdir(info_dir)
        info_path = os.path.join(info_dir, str(job_id) + ".info")
        json.dump(data_id_manager.id_to_info, open(info_path, "w"))
Example #22
    def work(self, job_id):
        qids = self.query_group[job_id]
        max_data_per_job = 1000 * 1000
        base = job_id * max_data_per_job
        data_id_manager = DataIDManager(base, base+max_data_per_job)
        output_path = os.path.join(self.out_dir, str(job_id))
        writer = RecordWriterWrap(output_path)

        for qid in qids:
            try:
                sr_per_qid = self.seg_resource_loader.load_for_qid(qid)
                docs_to_predict = select_one_pos_neg_doc(sr_per_qid.sr_per_query_doc)
                for sr_per_doc in docs_to_predict:
                    label_id = sr_per_doc.label
                    if self.skip_single_seg and len(sr_per_doc.segs) == 1:
                        continue
                    for seg_idx, seg in enumerate(sr_per_doc.segs):
                        info = {
                            'qid': qid,
                            'doc_id': sr_per_doc.doc_id,
                            'seg_idx': seg_idx
                        }
                        data_id = data_id_manager.assign(info)
                        feature = encode_sr(seg,
                                            self.max_seq_length,
                                            label_id,
                                            data_id)
                        writer.write_feature(feature)
            except FileNotFoundError:
                if qid in missing_qids:
                    pass
                else:
                    raise

        writer.close()
        info_save_path = os.path.join(self.info_dir, "{}.info".format(job_id))
        json.dump(data_id_manager.id_to_info, open(info_save_path, "w"))
Example #23
File: kdp_para.py  Project: clover3/Chair
    def work(self, job_id):
        cid = self.cids[job_id]
        entries: List[SimpleRankedListEntry] = self.ranked_list[str(cid)]
        max_items = 1000 * 1000
        base = job_id * max_items
        end = base + max_items
        data_id_manager = DataIDManager(base, end)
        insts = self.get_instances(cid, data_id_manager, entries)
        save_path = os.path.join(self.out_dir, str(job_id))
        writer = self.writer
        write_records_w_encode_fn(save_path, writer.encode, insts)
        info_dir = self.out_dir + "_info"
        exist_or_mkdir(info_dir)
        info_path = os.path.join(info_dir, str(job_id) + ".info")
        json.dump(data_id_manager.id_to_info, open(info_path, "w"))
Example #24
    def make_tfrecord(self, job_id: int):
        save_path = os.path.join(self.request_dir, str(job_id))
        kdp_list = pickle.load(open(save_path, "rb"))
        data_id_manager = DataIDManager(0, 1000 * 1000)
        print("{} kdp".format(len(kdp_list)))
        insts = self.qck_generator.generate(kdp_list, data_id_manager)
        record_save_path = os.path.join(self.tf_record_dir, str(job_id))
        write_records_w_encode_fn(record_save_path,
                                  self.qck_generator.encode_fn, insts)
        # Save for backup
        info_save_path = os.path.join(self.tf_record_dir,
                                      "{}.info".format(job_id))
        pickle.dump(data_id_manager.id_to_info, open(info_save_path, "wb"))
        # launch estimator
        add_estimator_job(job_id)
Example #25
    def work(self, job_id):
        max_data_per_job = 1000 * 1000
        base = job_id * max_data_per_job
        data_id_manager = DataIDManager(base, base + max_data_per_job)
        todo = self.qk_candidate[job_id:job_id + 1]
        tprint("Generating instances")
        insts: List = self.generator.generate(todo, data_id_manager)
        tprint("{} instances".format(len(insts)))
        save_path = os.path.join(self.out_dir, str(job_id))
        tprint("Writing")
        write_records_w_encode_fn(save_path, self.generator.encode_fn, insts)
        tprint("writing done")

        info_dir = self.out_dir + "_info"
        exist_or_mkdir(info_dir)
        info_path = os.path.join(info_dir, str(job_id) + ".info")
        json.dump(data_id_manager.id_to_info, open(info_path, "w"))
Example #26
def generate_selected_training_data_w_json(info, max_seq_length, save_dir,
                                           get_score_fn, max_seg):
    data_id_manager = DataIDManager(0, 1000000)
    tprint("data gen")

    def get_query_id_group(query_id):
        for st, ed in robust_query_intervals:
            if st <= int(query_id) <= ed:
                return st

        assert False

    tokenizer = get_tokenizer()
    for data_id, e in info.items():
        input_ids = tokenizer.convert_tokens_to_ids(e['tokens'])
        e['input_ids'] = input_ids

    maybe_num_insts = int(len(info) / 4)
    ticker = TimeEstimator(maybe_num_insts)
    itr = enum_best_segments(get_score_fn, info, max_seg)
    insts = collections.defaultdict(list)
    for selected_entry in itr:
        ticker.tick()
        selected = selected_entry
        query_id = selected['query_id']
        q_group = get_query_id_group(query_id)
        assert len(selected['tokens']) == len(selected['seg_ids'])
        input_ids = tokenizer.convert_tokens_to_ids(selected['tokens'])
        selected['input_ids'] = pad0(input_ids, max_seq_length)
        selected['seg_ids'] = pad0(selected['seg_ids'], max_seq_length)
        # data_id = data_id_manager.assign(selected_segment.to_info_d())
        data_id = 0  # info tracking disabled, so the .info dump below will be empty
        ci = InstAsInputIds(selected['input_ids'], selected['seg_ids'],
                            selected['label'], data_id)
        insts[q_group].append(ci)

    def encode_fn(inst: InstAsInputIds) -> collections.OrderedDict:
        return encode_inst_as_input_ids(max_seq_length, inst)

    tprint("writing")
    for q_group, insts_per_group in insts.items():
        out_path = os.path.join(save_dir, str(q_group))
        write_records_w_encode_fn(out_path, encode_fn, insts_per_group,
                                  len(insts_per_group))
        save_info(save_dir, data_id_manager, str(q_group) + ".info")
Example #27
def main():
    split = "train"
    resource = ProcessedResource(split)
    data_id_manager = DataIDManager(0)
    max_seq_length = 512
    basic_encoder = LeadingN(max_seq_length, 1)
    generator = PassageLengthInspector(resource, basic_encoder, max_seq_length)

    qids_all = []
    for job_id in range(40):
        qids = resource.query_group[job_id]
        # data_id_st / data_id_ed are computed here but never used
        data_bin = 100000
        data_id_st = job_id * data_bin
        data_id_ed = data_id_st + data_bin
        qids_all.extend(qids)

    tprint("generating instances")
    insts = generator.generate(data_id_manager, qids_all)
    generator.write(insts, "")
Example #28
    def work(self, job_id):
        data_id_man = DataIDManager()
        insts = self.generate_instances(job_id, data_id_man)
        save_path = os.path.join(self.out_dir, str(job_id))

        def encode_fn(inst: Instance):
            tokens1 = inst.tokens1
            max_seg2_len = self.max_seq_length - 3 - len(tokens1)

            tokens2 = inst.tokens2[:max_seg2_len]
            tokens = ["[CLS]"] + tokens1 + ["[SEP]"] + tokens2 + ["[SEP]"]

            segment_ids = [0] * (len(tokens1) + 2) \
                          + [1] * (len(tokens2) + 1)
            tokens = tokens[:self.max_seq_length]
            segment_ids = segment_ids[:self.max_seq_length]
            features = get_basic_input_feature(self.tokenizer, self.max_seq_length, tokens, segment_ids)
            features['label_ids'] = create_int_feature([inst.label])
            features['data_id'] = create_int_feature([inst.data_id])
            return features

        write_records_w_encode_fn(save_path, encode_fn, insts)
        info_save_path = os.path.join(self.info_out_dir, str(job_id))
        json.dump(data_id_man.id_to_info, open(info_save_path, "w"))
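
Example #28 assembles features by hand. create_int_feature below follows the standard helper from the original BERT codebase; get_basic_input_feature is sketched under the assumption that it converts tokens to ids and zero-pads the three standard BERT inputs:

import collections
import tensorflow as tf

def create_int_feature(values):
    # Standard helper from the original BERT codebase.
    return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

def get_basic_input_feature(tokenizer, max_seq_length, tokens, segment_ids):
    # Assumed sketch: ids + mask + segment ids, zero-padded to max_seq_length.
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_mask = [1] * len(input_ids)
    segment_ids = list(segment_ids)
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
    features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(input_ids)
    features["input_mask"] = create_int_feature(input_mask)
    features["segment_ids"] = create_int_feature(segment_ids)
    return features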
Example #29
def do_job(input_dir, output_dir, info_dir, label_info_path, max_entries,
           job_id):

    exist_or_mkdir(output_dir)
    info_output_dir = output_dir + "_info"
    exist_or_mkdir(info_output_dir)

    label_info: List[Tuple[str, str, int]] = json.load(open(label_info_path, "r"))
    label_info_d = {(str(a), str(b)): c for a, b, c in label_info}

    pred_path = os.path.join(input_dir, str(job_id) + ".score")
    #info_path = os.path.join(info_dir, str(job_id) + ".info")
    info_path = info_dir
    output_path = os.path.join(output_dir, str(job_id))
    info_output_path = os.path.join(info_output_dir, str(job_id))
    info = load_combine_info_jsons(info_path, qck_convert_map, True)
    fetch_field_list = ["vector", "data_id"]

    predictions = join_prediction_with_info(pred_path, info, fetch_field_list)

    def get_qid(entry):
        return entry['query'].query_id

    def get_candidate_id(entry):
        return entry['candidate'].id

    def pair_id(entry) -> Tuple[str, str]:
        return get_qid(entry), get_candidate_id(entry)

    groups: Dict[Tuple[str, str], List[Dict]] = group_by(predictions, pair_id)

    def get_new_entry(entries: List[Dict]):
        if not entries:
            return None
        vectors: Vectors = list([e['vector'] for e in entries])
        key = pair_id(entries[0])
        if key in label_info_d:
            label: Label = label_info_d[key]
        else:
            label: Label = 0

        return vectors, label

    g2: Dict[Tuple[str, str],
             Tuple[Vectors, Label]] = dict_value_map(get_new_entry, groups)
    base = 100 * 1000 * job_id
    max_count = 100 * 1000 * (job_id + 1)
    data_id_manager = DataIDManager(base, max_count)

    def get_out_itr() -> Iterable[Tuple[int, Tuple[Vectors, Label]]]:
        for key, data in g2.items():
            qid, cid = key
            data_info = {
                'qid': qid,
                'cid': cid,
            }
            data_id = data_id_manager.assign(data_info)
            yield data_id, data

    write_to_file(output_path, get_out_itr(), max_entries)
    json.dump(data_id_manager.id_to_info, open(info_output_path, "w"))
Example #30
def write_records_w_encode_fn_mt(
    output_path,
    encode,
    records,
):
    # note: despite the "_mt" suffix, this version encodes records serially via map()
    writer = RecordWriterWrap(output_path)
    tprint("Making records")
    features_list: List[OrderedDict] = list(map(encode, records))
    tprint("Total of {} records".format(len(features_list)))
    writer.write_feature_list(features_list)
    writer.close()

    tprint("Done writing")


if __name__ == "__main__":
    data_id_manager = DataIDManager(0, 1000 * 1000)
    job_id = 4
    request_dir = os.environ["request_dir"]
    save_path = os.path.join(request_dir, str(job_id))
    kdp_list = pickle.load(open(save_path, "rb"))
    kdp_list = kdp_list[:2]
    qck_generator: QCKGenDynamicKDP = get_qck_gen_dynamic_kdp()
    tf_record_dir = os.environ["tf_record_dir"]

    print("{} kdp".format(len(kdp_list)))
    insts = qck_generator.generate(kdp_list, data_id_manager)
    record_save_path = os.path.join(tf_record_dir, str(job_id) + ".test")
    write_records_w_encode_fn_mt(record_save_path, qck_generator.encode_fn,
                                 insts)
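
RecordWriterWrap (Examples #22 and #30) exposes write_feature, write_feature_list, and close. A minimal TFRecord-backed sketch consistent with those call sites, offered as an assumption about the real class rather than a copy of it:

import tensorflow as tf

class RecordWriterWrap:
    def __init__(self, output_path):
        self.writer = tf.io.TFRecordWriter(output_path)
        self.total_written = 0

    def write_feature(self, features):
        # `features` is an OrderedDict of tf.train.Feature, as built by the encode fns.
        example = tf.train.Example(features=tf.train.Features(feature=features))
        self.writer.write(example.SerializeToString())
        self.total_written += 1

    def write_feature_list(self, features_list):
        for features in features_list:
            self.write_feature(features)

    def close(self):
        self.writer.close()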