def main():
    """Start the "first_128_desc" job run.

    Builds a ``LeadingN`` encoder for 128-token passages with one segment,
    wraps it in a ``RobustPointwiseTrainGenEx`` generator whose sequence
    length equals the passage length, and hands a ``RobustWorker`` factory
    bound to that generator to a ``JobRunner``.
    """
    passage_len = 128
    segment_count = 1
    leading_encoder = LeadingN(passage_len, segment_count)

    # The generator's sequence length is deliberately the same value as the
    # encoder's passage length.
    # NOTE(review): "desc" presumably selects the description query field -- confirm.
    generator = RobustPointwiseTrainGenEx(leading_encoder, passage_len, "desc")

    # JobRunner supplies the remaining RobustWorker argument(s) itself.
    make_worker = partial(RobustWorker, generator)

    # NOTE(review): 4 is presumably the highest job index -- confirm against JobRunner.
    JobRunner(job_man_dir, 4, "first_128_desc", make_worker).start()
def main():
    """Run ``PassageLengthInspector`` over query groups 0-39 of the train split.

    Collects the query ids of the first 40 entries of
    ``resource.query_group`` into one list, generates instances for all of
    them in a single ``generate`` call, and passes the result to
    ``generator.write`` with an empty output path.
    """
    split = "train"
    resource = ProcessedResource(split)
    data_id_manager = DataIDManager(0)
    max_seq_length = 512
    basic_encoder = LeadingN(max_seq_length, 1)
    generator = PassageLengthInspector(resource, basic_encoder, max_seq_length)

    # Gather the query ids of every job group before generating anything.
    qids_all = []
    for job_id in range(40):
        qids_all.extend(resource.query_group[job_id])

    # NOTE(review): generation is assumed to run once, after the loop, on the
    # accumulated ids -- confirm against the original layout.
    tprint("generating instances")
    insts = generator.generate(data_id_manager, qids_all)
    # NOTE(review): the empty path presumably means nothing is persisted -- confirm.
    generator.write(insts, "")
from typing import List, Dict

from data_generator.job_runner import JobRunner
from epath import job_man_dir
from tlm.data_gen.adhoc_datagen import LeadingN
from tlm.data_gen.msmarco_doc_gen.gen_worker import MMDWorker, PointwiseGen, \
    FirstPassagePairGenerator
from tlm.data_gen.msmarco_doc_gen.processed_resource import ProcessedResource, ProcessedResource10doc, \
    ProcessedResource50doc

if __name__ == "__main__":
    # Script entry point: run the "MMD_pair_first" job over the train split,
    # using a FirstPassagePairGenerator fed by a LeadingN encoder.
    split = "train"
    resource = ProcessedResource(split)
    max_seq_length = 512
    basic_encoder = LeadingN(max_seq_length, 1)
    generator = FirstPassagePairGenerator(resource, basic_encoder, max_seq_length)

    def factory(out_dir):
        """Build an MMDWorker for ``out_dir`` sharing the resource and generator."""
        return MMDWorker(resource.query_group, generator, out_dir)

    # The job name has no replacement field, so it is passed as a plain
    # literal; it does not vary with ``split``.
    # NOTE(review): the "- 1" suggests JobRunner takes the highest job index
    # rather than a count -- confirm.
    runner = JobRunner(job_man_dir, len(resource.query_group) - 1, "MMD_pair_first", factory)
    runner.start()
from typing import List, Dict

from data_generator.job_runner import JobRunner
from epath import job_man_dir
from tlm.data_gen.adhoc_datagen import LeadingN
from tlm.data_gen.msmarco_doc_gen.gen_worker import MMDWorker, PointwiseGen
from tlm.data_gen.msmarco_doc_gen.processed_resource import ProcessedResource

if __name__ == "__main__":
    # Script entry point: run the "MMD_train_256" job over the train split,
    # using a PointwiseGen generator fed by a LeadingN encoder.
    split = "train"
    resource = ProcessedResource(split)
    max_seq_length = 512
    # NOTE(review): the encoder is built with 256 while the generator gets
    # max_seq_length (512); this matches the "256" in the job name and looks
    # intentional -- confirm.
    basic_encoder = LeadingN(256, 1)
    generator = PointwiseGen(resource, basic_encoder, max_seq_length)

    def factory(out_dir):
        """Build an MMDWorker for ``out_dir`` sharing the resource and generator."""
        return MMDWorker(resource.query_group, generator, out_dir)

    # The job name has no replacement field, so it is passed as a plain
    # literal; it does not vary with ``split``.
    # NOTE(review): the "- 1" suggests JobRunner takes the highest job index
    # rather than a count -- confirm.
    runner = JobRunner(job_man_dir, len(resource.query_group) - 1, "MMD_train_256", factory)
    runner.start()
def __init__(self, max_seq_length):
    """Store the sequence length and build the tokenizer and encoder.

    Args:
        max_seq_length: Length passed as the first argument to ``LeadingN``
            and kept on the instance as ``self.max_seq_length``.
    """
    self.max_seq_length = max_seq_length
    self.tokenizer = get_tokenizer()
    # NOTE(review): the second LeadingN argument is presumably the segment
    # count (a single segment here) -- confirm.
    self.encoder = LeadingN(max_seq_length, 1)