from tlm.data_gen.msmarco_doc_gen.processed_resource import ProcessedResource
# log_variables is assumed to be this repo's debug-logging helper;
# its import path is not shown in the original.


class Verifier:
    def __init__(self):
        self.resource = ProcessedResource("train")

    def work(self, job_id):
        # Check that every candidate document for this job's queries
        # actually has tokenized content available.
        qid_list = self.resource.query_group[job_id]
        for qid in qid_list:
            if qid not in self.resource.candidate_doc_d:
                continue
            target_docs = self.resource.candidate_doc_d[qid]
            tokens_d = self.resource.get_doc_tokens_d(qid)
            for doc_id in target_docs:
                if doc_id not in tokens_d:
                    log_variables(qid, target_docs)
                    print("Not found:", doc_id)
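# Usage sketch (illustrative, not part of the original file): run the
# verification pass over every job, mirroring the 40-job range used in
# main() below.
def verify_all_jobs():
    verifier = Verifier()
    for job_id in range(40):
        verifier.work(job_id)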
from tlm.data_gen.adhoc_datagen import LeadingN
from tlm.data_gen.msmarco_doc_gen.processed_resource import ProcessedResource
# DataIDManager, PassageLengthInspector, and tprint are assumed to come from
# this repo's data-gen and logging modules; their import paths are not shown
# in the original.


def main():
    split = "train"
    resource = ProcessedResource(split)
    data_id_manager = DataIDManager(0)
    max_seq_length = 512
    basic_encoder = LeadingN(max_seq_length, 1)
    generator = PassageLengthInspector(resource, basic_encoder, max_seq_length)

    # Collect the query ids from all 40 jobs before generating instances.
    qids_all = []
    for job_id in range(40):
        qids = resource.query_group[job_id]
        data_bin = 100000
        # Per-job data id range (computed but not used below).
        data_id_st = job_id * data_bin
        data_id_ed = data_id_st + data_bin
        qids_all.extend(qids)

    tprint("generating instances")
    insts = generator.generate(data_id_manager, qids_all)
    generator.write(insts, "")


if __name__ == "__main__":
    main()
from data_generator.job_runner import JobRunner
from epath import job_man_dir
from tlm.data_gen.adhoc_datagen import LeadingN
from tlm.data_gen.msmarco_doc_gen.gen_worker import MMDWorker, FirstPassagePairGenerator
from tlm.data_gen.msmarco_doc_gen.processed_resource import ProcessedResource

if __name__ == "__main__":
    split = "train"
    resource = ProcessedResource(split)
    max_seq_length = 512
    basic_encoder = LeadingN(max_seq_length, 1)
    generator = FirstPassagePairGenerator(resource, basic_encoder, max_seq_length)

    def factory(out_dir):
        return MMDWorker(resource.query_group, generator, out_dir)

    runner = JobRunner(job_man_dir, len(resource.query_group) - 1,
                       "MMD_pair_first_{}".format(split), factory)
    runner.start()