示例#1
0
class Verifier:
    """Check that candidate documents of each query have a token entry.

    For every query id in a job's query group, the candidate document ids
    (``resource.candidate_doc_d[qid]``) are compared against the keys of the
    per-query token dictionary (``resource.get_doc_tokens_d(qid)``); any
    candidate id absent from that dictionary is reported.
    """

    def __init__(self, split="train"):
        # ``split`` defaults to "train", the value previously hard-coded, so
        # existing ``Verifier()`` callers behave exactly as before.
        self.resource = ProcessedResource(split)

    def work(self, job_id):
        """Report candidate documents that have no token entry.

        Args:
            job_id: key into ``self.resource.query_group``; selects the list
                of query ids to verify.

        Queries with no entry in ``candidate_doc_d`` are skipped. For each
        missing document, ``log_variables(qid, target_docs)`` is called and
        the missing document id is printed. Returns None.
        """
        resource = self.resource
        for qid in resource.query_group[job_id]:
            # Not every query has candidates; nothing to verify for those.
            if qid not in resource.candidate_doc_d:
                continue

            target_docs = resource.candidate_doc_d[qid]
            tokens_d = resource.get_doc_tokens_d(qid)

            for doc_id in target_docs:
                if doc_id not in tokens_d:
                    log_variables(qid, target_docs)
                    print("Not found: ", doc_id)
示例#2
0
def main(num_jobs=40):
    """Generate PassageLengthInspector instances for the "train" split.

    Collects the query ids of query groups ``0 .. num_jobs - 1`` into a
    single ordered list, generates instances for them with
    ``PassageLengthInspector.generate`` and writes them with
    ``generator.write``.

    Args:
        num_jobs: number of leading query groups to include. Defaults to 40,
            the value previously hard-coded.
    """
    split = "train"
    resource = ProcessedResource(split)
    data_id_manager = DataIDManager(0)
    max_seq_length = 512
    basic_encoder = LeadingN(max_seq_length, 1)
    generator = PassageLengthInspector(resource, basic_encoder, max_seq_length)

    # Flatten the selected query groups into one list, preserving order.
    qids_all = []
    for job_id in range(num_jobs):
        qids_all.extend(resource.query_group[job_id])

    tprint("generating instances")
    insts = generator.generate(data_id_manager, qids_all)
    # NOTE(review): the output path is the empty string; confirm this is
    # what PassageLengthInspector.write expects.
    generator.write(insts, "")
示例#3
0
from typing import List, Dict

from data_generator.job_runner import JobRunner
from epath import job_man_dir
from tlm.data_gen.adhoc_datagen import LeadingN
from tlm.data_gen.msmarco_doc_gen.gen_worker import MMDWorker, PointwiseGen, \
    FirstPassagePairGenerator
from tlm.data_gen.msmarco_doc_gen.processed_resource import ProcessedResource, ProcessedResource10doc, \
    ProcessedResource50doc

if __name__ == "__main__":
    # Build a FirstPassagePairGenerator over the "train" split and run one
    # MMDWorker job per query group through JobRunner.
    split = "train"
    resource = ProcessedResource(split)
    max_seq_length = 512
    basic_encoder = LeadingN(max_seq_length, 1)
    generator = FirstPassagePairGenerator(resource, basic_encoder,
                                          max_seq_length)

    def factory(out_dir):
        """Create an MMDWorker over all query groups for ``out_dir``."""
        return MMDWorker(resource.query_group, generator, out_dir)

    # The job name has no "{}" placeholder, so the former ``.format(split)``
    # call was a no-op; the literal is kept byte-identical so the job name
    # handed to JobRunner does not change.
    # NOTE(review): the job-count argument is len(query_group) - 1; confirm
    # against JobRunner whether this bound is meant to be inclusive.
    runner = JobRunner(job_man_dir,
                       len(resource.query_group) - 1,
                       "MMD_pair_first",
                       factory)
    runner.start()
示例#4
0
 def __init__(self):
     self.resource = ProcessedResource("train")