Example #1
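    # TREC CAR: build the document indices from the paragraph collection, then
    # download the train folds, benchmarkY1 test, and test200 queries/qrels
    # into the dataset directory.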
    def init(self, force=False):
        base_path = util.path_dataset(self)
        base = Path(base_path)

        # DOCUMENT COLLECTION
        idx = [self.index, self.index_stem, self.doc_store]
        self._init_indices_parallel(idx, self._init_doc_iter(), force)

        # TRAIN

        files = {}
        files.update({
            base / f'train-f{f}.auto.qrels': f'train/train.fold{f}.cbor.hierarchical.qrels' for f in range(5)
        })
        files.update({
            base / f'train-f{f}.queries.tsv': f'train/train.fold{f}.cbor.outlines' for f in range(5)
        })
        if (force or not all(f.exists() for f in files)) and self._confirm_dua():
            with util.download_tmp(_SOURCES['train'], tarf=True) as f:
                for member in f:
                    for f_out, f_in in files.items():
                        if member.name == f_in:
                            if f_out.suffix == '.qrels':
                                self._init_file_copy(f.extractfile(member), f_out, force)
                            elif f_out.suffix == '.tsv':
                                self._init_queryfile(f.extractfile(member), f_out, force)

        # TEST

        files = {
            base / 'test.queries.tsv': 'benchmarkY1test.public/test.benchmarkY1test.cbor.outlines'
        }
        if (force or not all(f.exists() for f in files)) and self._confirm_dua():
            with util.download_tmp(_SOURCES['test'], tarf=True) as f:
                for f_out, f_in in files.items():
                    self._init_queryfile(f.extractfile(f_in), f_out, force)

        files = {
            base / 'test.auto.qrels': 'TREC_CAR_2017_qrels/automatic.benchmarkY1test.cbor.hierarchical.qrels',
            base / 'test.manual.qrels': 'TREC_CAR_2017_qrels/manual.benchmarkY1test.cbor.hierarchical.qrels',
        }
        if (force or not all(f.exists() for f in files)) and self._confirm_dua():
            with util.download_tmp(_SOURCES['test-qrels'], tarf=True) as f:
                for f_out, f_in in files.items():
                    self._init_file_copy(f.extractfile(f_in), f_out, force)

        # TEST200

        files = {
            base / 'test200.auto.qrels': 'test200/train.test200.cbor.hierarchical.qrels',
            base / 'test200.queries.tsv': 'test200/train.test200.cbor.outlines',
        }
        if (force or not all(f.exists() for f in files)) and self._confirm_dua():
            with util.download_tmp(_SOURCES['test200'], tarf=True) as f:
                for f_out, f_in in files.items():
                    if f_out.suffix == '.qrels':
                        self._init_file_copy(f.extractfile(f_in), f_out, force)
                    elif f_out.suffix == '.tsv':
                        self._init_queryfile(f.extractfile(f_in), f_out, force)
Example #2
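    # Pair each document in collection.tsv with its docTTTTTquery expansion
    # predictions and yield the document text with the predicted queries
    # appended.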
    def _init_doctttttquery_iter(self):
        with util.download_tmp(_SOURCES['doctttttquery-predictions'], expected_md5=_HASHES['doctttttquery-predictions']) as f1, \
             util.download_tmp(_SOURCES['collection'], expected_md5=_HASHES['collection']) as f2:
            with zipfile.ZipFile(f1) as zipf, tarfile.open(fileobj=f2) as tarf:
                collection_stream = io.TextIOWrapper(tarf.extractfile('collection.tsv'))
                d5_iter = self._init_doctttttquery_zipf_iter(zipf)
                for (did, text), d5text in self.logger.pbar(zip(plaintext.read_tsv(collection_stream), d5_iter), desc='documents'):
                    yield indices.RawDoc(did, f'{text} {d5text}')
Example #3
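    # Stream collection.tsv out of the downloaded tarball, yielding one RawDoc
    # per (doc id, text) line.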
    def _init_iter_collection(self):
        with util.download_tmp(_SOURCES['collection']) as f:
            with tarfile.open(fileobj=f) as tarf:
                collection_stream = io.TextIOWrapper(
                    tarf.extractfile('collection.tsv'))
                for did, text in self.logger.pbar(
                        plaintext.read_tsv(collection_stream),
                        desc='documents'):
                    yield indices.RawDoc(did, text)
Example #4
    def _init_iter_collection(self):
        # Using the trick here from capreolus, pulling document content out of a public index:
        # <https://github.com/capreolus-ir/capreolus/blob/d6ae210b24c32ff817f615370a9af37b06d2da89/capreolus/collection/robust04.yaml#L15>
        with util.download_tmp(**_FILES['index']) as f:
            fd = f'{f.name}.d'
            util.extract_tarball(f.name,
                                 fd,
                                 self.logger,
                                 reset_permissions=True)
            index = indices.AnseriniIndex(f'{fd}/index-robust04-20191213')
            for did in self.logger.pbar(index.docids(), desc='documents'):
                raw_doc = index.get_raw(did)
                yield indices.RawDoc(did, raw_doc)
Example #5
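    # CORD-19: download the full-text tarballs and metadata CSV for the
    # configured snapshot date and yield one document per metadata row.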
    def _init_iter_collection(self):
        files = {
            '2020-04-10': {
                'comm_use_subset':
                ('https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/2020-04-10/comm_use_subset.tar.gz',
                 "253cecb4fee2582a611fb77a4d537dc5"),
                'noncomm_use_subset':
                ('https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/2020-04-10/noncomm_use_subset.tar.gz',
                 "734b462133b3c00da578a909f945f4ae"),
                'custom_license':
                ('https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/2020-04-10/custom_license.tar.gz',
                 "2f1c9864348025987523b86d6236c40b"),
                'biorxiv_medrxiv':
                ('https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/2020-04-10/biorxiv_medrxiv.tar.gz',
                 "c12acdec8b3ad31918d752ba3db36121"),
            },
            '2020-05-01': {
                'comm_use_subset':
                ('https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/2020-05-01/comm_use_subset.tar.gz',
                 "af4202340182209881d3d8cba2d58a24"),
                'noncomm_use_subset':
                ('https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/2020-05-01/noncomm_use_subset.tar.gz',
                 "9cc25b9e8674197446e7cbd4381f643b"),
                'custom_license':
                ('https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/2020-05-01/custom_license.tar.gz',
                 "1cb6936a7300a31344cd8a5ecc9ca778"),
                'biorxiv_medrxiv':
                ('https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/2020-05-01/biorxiv_medrxiv.tar.gz',
                 "9d6c6dc5d64b01e528086f6652b3ccb7"),
                'arxiv':
                ('https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/2020-05-01/arxiv.tar.gz',
                 "f10890174d6f864f306800d4b02233bc"),
            }
        }
        metadata = {
            '2020-04-10':
            ('https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/2020-04-10/metadata.csv',
             "42a21f386be86c24647a41bedde34046"),
            '2020-05-01':
            ('https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/2020-05-01/metadata.csv',
             "b1d2e409026494e0c8034278bacd1248"),
        }
        meta_url, meta_md5 = metadata[self.config['date']]

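        # Keep all full-text tarballs open while iterating the metadata; for
        # each row, prefer the PMC XML parse over the PDF parse and assemble
        # title, abstract, body, and section-heading fields.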
        fulltexts = {}
        with contextlib.ExitStack() as stack:
            for fid, (file, md5) in files[self.config['date']].items():
                fulltexts[fid] = stack.enter_context(
                    util.download_tmp(file, tarf=True, expected_md5=md5))
            meta = pd.read_csv(
                util.download_stream(meta_url, expected_md5=meta_md5))
            for _, row in meta.iterrows():
                did = str(row['cord_uid'])
                title = str(row['title'])
                doi = str(row['doi'])
                abstract = str(row['abstract'])
                date = str(row['publish_time'])
                body = ''
                heads = ''
                if row['has_pmc_xml_parse']:
                    path = os.path.join(row['full_text_file'], 'pmc_json',
                                        row['pmcid'] + '.xml.json')
                    data = json.load(
                        fulltexts[row['full_text_file']].extractfile(path))
                    if 'body_text' in data:
                        body = '\n'.join(b['text'] for b in data['body_text'])
                        heads = '\n'.join(
                            set(b['section'] for b in data['body_text']))
                elif row['has_pdf_parse']:
                    path = os.path.join(
                        row['full_text_file'], 'pdf_json',
                        row['sha'].split(';')[0].strip() + '.json')
                    data = json.load(
                        fulltexts[row['full_text_file']].extractfile(path))
                    if 'body_text' in data:
                        body = '\n'.join(b['text'] for b in data['body_text'])
                        heads = '\n'.join(
                            set(b['section'] for b in data['body_text']))
                contents = f'{title}\n\n{abstract}\n\n{body}\n\n{heads}'
                doc = indices.RawDoc(did,
                                     text=contents,
                                     title=title,
                                     abstract=abstract,
                                     title_abs=f'{title}\n\n{abstract}',
                                     body=body,
                                     doi=doi,
                                     date=date)
                yield doc
Example #6
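    # TREC-COVID: build the document indices (including a 2020-only variant),
    # then fetch the round 1/2 topics, the UDel query variants, and the
    # round 1 qrels.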
    def init(self, force=False):
        needs_docs = []
        for index in [self.index_stem, self.index_stem_2020, self.doc_store]:
            if force or not index.built():
                needs_docs.append(index)

        if needs_docs and self._confirm_dua():
            with contextlib.ExitStack() as stack:
                doc_iter = self._init_iter_collection()
                doc_iter = self.logger.pbar(doc_iter, desc='articles')
                doc_iters = util.blocking_tee(doc_iter, len(needs_docs))
                for idx, it in zip(needs_docs, doc_iters):
                    if idx is self.index_stem_2020:
                        it = (d for d in it if '2020' in d.data['date'])
                    stack.enter_context(
                        util.CtxtThread(functools.partial(idx.build, it)))

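        # Round 1 topics: flatten the XML into a TSV with one
        # (qid, field, text) row per query/question/narrative field.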
        path = os.path.join(util.path_dataset(self), 'rnd1.tsv')
        if not os.path.exists(path) and self._confirm_dua():
            with util.download_tmp('https://ir.nist.gov/covidSubmit/data/topics-rnd1.xml', expected_md5="cf1b605222f45f7dbc90ca8e4d9b2c31") as f, \
                 util.finialized_file(path, 'wt') as fout:
                soup = BeautifulSoup(f.read(), 'lxml-xml')
                for topic in soup.find_all('topic'):
                    qid = topic['number']
                    plaintext.write_tsv(fout, [
                        (qid, 'query', topic.find('query').get_text()),
                        (qid, 'quest', topic.find('question').get_text()),
                        (qid, 'narr', topic.find('narrative').get_text()),
                    ])

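        # Append the UDel query variants from Anserini's topic files, using a
        # flag file to record that they have already been merged in.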
        udel_flag = path + '.includes_udel'
        if not os.path.exists(udel_flag):
            with open(path,
                      'at') as fout, util.finialized_file(udel_flag, 'wt'):
                with util.download_tmp(
                        'https://raw.githubusercontent.com/castorini/anserini/master/src/main/resources/topics-and-qrels/topics.covid-round1-udel.xml',
                        expected_md5="2915cf59ae222f0aa20b2a671f67fd7a") as f:
                    soup = BeautifulSoup(f.read(), 'lxml-xml')
                    for topic in soup.find_all('topic'):
                        qid = topic['number']
                        plaintext.write_tsv(fout, [
                            (qid, 'udel', topic.find('query').get_text()),
                        ])

        path = os.path.join(util.path_dataset(self), 'rnd2.tsv')
        if not os.path.exists(path) and self._confirm_dua():
            with util.download_tmp('https://ir.nist.gov/covidSubmit/data/topics-rnd2.xml', expected_md5="550129e71c83de3fb4d6d29a172c5842") as f, \
                 util.finialized_file(path, 'wt') as fout:
                soup = BeautifulSoup(f.read(), 'lxml-xml')
                for topic in soup.find_all('topic'):
                    qid = topic['number']
                    plaintext.write_tsv(fout, [
                        (qid, 'query', topic.find('query').get_text()),
                        (qid, 'quest', topic.find('question').get_text()),
                        (qid, 'narr', topic.find('narrative').get_text()),
                    ])

        udel_flag = path + '.includes_udel'
        if not os.path.exists(udel_flag):
            with open(path,
                      'at') as fout, util.finialized_file(udel_flag, 'wt'):
                with util.download_tmp(
                        'https://raw.githubusercontent.com/castorini/anserini/master/src/main/resources/topics-and-qrels/topics.covid-round2-udel.xml',
                        expected_md5="a8988734e6f812921d5125249c197985") as f:
                    soup = BeautifulSoup(f.read(), 'lxml-xml')
                    for topic in soup.find_all('topic'):
                        qid = topic['number']
                        plaintext.write_tsv(fout, [
                            (qid, 'udel', topic.find('query').get_text()),
                        ])

        path = os.path.join(util.path_dataset(self), 'rnd1.qrels')
        if not os.path.exists(path) and self._confirm_dua():
            util.download(
                'https://ir.nist.gov/covidSubmit/data/qrels-rnd1.txt',
                path,
                expected_md5="d58586df5823e7d1d0b3619a73b31518")
Example #7
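    # Extract the TREC CAR paragraph corpus cbor from the tarball and yield
    # one RawDoc per paragraph.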
    def _init_doc_iter(self):
        with util.download_tmp(_SOURCES['corpus'], tarf=True) as f:
            cbor_file = f.extractfile('paragraphcorpus/paragraphcorpus.cbor')
            for did, text in self.logger.pbar(car.iter_paras(cbor_file), desc='documents'):
                yield indices.RawDoc(did, text)
Example #8
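    # MS MARCO: build the indices, then derive the train/minidev/dev/eval
    # queries, qrels, training pairs, and runs, plus the TREC DL 2019 files
    # and the judged/train10/train_med variants.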
    def init(self, force=False):
        idxs = [self.index_stem, self.doc_store]
        self._init_indices_parallel(idxs, self._init_iter_collection(), force)

        base_path = util.path_dataset(self)

        needs_queries = []
        if force or not os.path.exists(
                os.path.join(base_path, 'train.queries.tsv')):
            needs_queries.append(lambda it: plaintext.write_tsv(
                os.path.join(base_path, 'train.queries.tsv'),
                ((qid, txt) for file, qid, txt in it
                 if file == 'queries.train.tsv' and qid not in MINI_DEV)))
        if force or not os.path.exists(
                os.path.join(base_path, 'minidev.queries.tsv')):
            needs_queries.append(lambda it: plaintext.write_tsv(
                os.path.join(base_path, 'minidev.queries.tsv'),
                ((qid, txt) for file, qid, txt in it
                 if file == 'queries.train.tsv' and qid in MINI_DEV)))
        if force or not os.path.exists(
                os.path.join(base_path, 'dev.queries.tsv')):
            needs_queries.append(lambda it: plaintext.write_tsv(
                os.path.join(base_path, 'dev.queries.tsv'),
                ((qid, txt) for file, qid, txt in it
                 if file == 'queries.dev.tsv')))
        if force or not os.path.exists(
                os.path.join(base_path, 'eval.queries.tsv')):
            needs_queries.append(lambda it: plaintext.write_tsv(
                os.path.join(base_path, 'eval.queries.tsv'),
                ((qid, txt) for file, qid, txt in it
                 if file == 'queries.eval.tsv')))

        if needs_queries and self._confirm_dua():
            with util.download_tmp(_SOURCES['queries']) as f, \
                 tarfile.open(fileobj=f) as tarf, \
                 contextlib.ExitStack() as ctxt:

                def _extr_subf(subf):
                    for qid, txt in plaintext.read_tsv(
                            io.TextIOWrapper(tarf.extractfile(subf))):
                        yield subf, qid, txt

                query_iter = [
                    _extr_subf('queries.train.tsv'),
                    _extr_subf('queries.dev.tsv'),
                    _extr_subf('queries.eval.tsv')
                ]
                query_iter = tqdm(itertools.chain(*query_iter), desc='queries')
                query_iters = util.blocking_tee(query_iter, len(needs_queries))
                for fn, it in zip(needs_queries, query_iters):
                    ctxt.enter_context(
                        util.CtxtThread(functools.partial(fn, it)))

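        # Split the official train qrels into train (MINI_DEV held out) and
        # minidev subsets, and convert the dev qrels to TREC format.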
        file = os.path.join(base_path, 'train.qrels')
        if (force or not os.path.exists(file)) and self._confirm_dua():
            stream = util.download_stream(_SOURCES['train-qrels'], 'utf8')
            with util.finialized_file(file, 'wt') as out:
                for qid, _, did, score in plaintext.read_tsv(stream):
                    if qid not in MINI_DEV:
                        trec.write_qrels(out, [(qid, did, score)])

        file = os.path.join(base_path, 'minidev.qrels')
        if (force or not os.path.exists(file)) and self._confirm_dua():
            stream = util.download_stream(_SOURCES['train-qrels'], 'utf8')
            with util.finialized_file(file, 'wt') as out:
                for qid, _, did, score in plaintext.read_tsv(stream):
                    if qid in MINI_DEV:
                        trec.write_qrels(out, [(qid, did, score)])

        file = os.path.join(base_path, 'dev.qrels')
        if (force or not os.path.exists(file)) and self._confirm_dua():
            stream = util.download_stream(_SOURCES['dev-qrels'], 'utf8')
            with util.finialized_file(file, 'wt') as out:
                for qid, _, did, score in plaintext.read_tsv(stream):
                    trec.write_qrels(out, [(qid, did, score)])

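        # Training pairs: adopt the legacy qidpidtriples file name if present,
        # otherwise download the pairs file.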
        file = os.path.join(base_path, 'train.mspairs.gz')
        if not os.path.exists(file) and os.path.exists(
                os.path.join(base_path, 'qidpidtriples.train.full')):
            # legacy
            os.rename(os.path.join(base_path, 'qidpidtriples.train.full'),
                      file)
        if (force or not os.path.exists(file)) and self._confirm_dua():
            util.download(_SOURCES['qidpidtriples.train.full'], file)

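        # Convert the official top1000 TSV rankings into TREC run files,
        # carving the MINI_DEV queries out of the train run into
        # minidev.msrun.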
        if not self.config['init_skip_msrun']:
            for file_name, subf in [('dev.msrun', 'top1000.dev'),
                                    ('eval.msrun', 'top1000.eval'),
                                    ('train.msrun', 'top1000.train.txt')]:
                file = os.path.join(base_path, file_name)
                if (force or not os.path.exists(file)) and self._confirm_dua():
                    run = {}
                    with util.download_tmp(_SOURCES[file_name]) as f, \
                         tarfile.open(fileobj=f) as tarf:
                        for qid, did, _, _ in tqdm(
                                plaintext.read_tsv(
                                    io.TextIOWrapper(tarf.extractfile(subf)))):
                            if qid not in run:
                                run[qid] = {}
                            run[qid][did] = 0.
                    if file_name == 'train.msrun':
                        minidev = {
                            qid: dids
                            for qid, dids in run.items() if qid in MINI_DEV
                        }
                        with self.logger.duration('writing minidev.msrun'):
                            trec.write_run_dict(
                                os.path.join(base_path, 'minidev.msrun'),
                                minidev)
                        run = {
                            qid: dids
                            for qid, dids in run.items() if qid not in MINI_DEV
                        }
                    with self.logger.duration(f'writing {file_name}'):
                        trec.write_run_dict(file, run)

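        # TREC DL 2019: queries, a run converted from TSV, and qrels, plus
        # "judged" variants restricted to queries that have relevance
        # judgments.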
        query_path = os.path.join(base_path, 'trec2019.queries.tsv')
        if (force or not os.path.exists(query_path)) and self._confirm_dua():
            stream = util.download_stream(_SOURCES['trec2019.queries'], 'utf8')
            plaintext.write_tsv(query_path, plaintext.read_tsv(stream))
        msrun_path = os.path.join(base_path, 'trec2019.msrun')
        if (force or not os.path.exists(msrun_path)) and self._confirm_dua():
            run = {}
            with util.download_stream(_SOURCES['trec2019.msrun'],
                                      'utf8') as stream:
                for qid, did, _, _ in plaintext.read_tsv(stream):
                    if qid not in run:
                        run[qid] = {}
                    run[qid][did] = 0.
            with util.finialized_file(msrun_path, 'wt') as f:
                trec.write_run_dict(f, run)

        qrels_path = os.path.join(base_path, 'trec2019.qrels')
        if not os.path.exists(qrels_path) and self._confirm_dua():
            util.download(_SOURCES['trec2019.qrels'], qrels_path)
        qrels_path = os.path.join(base_path, 'judgedtrec2019.qrels')
        if not os.path.exists(qrels_path):
            os.symlink('trec2019.qrels', qrels_path)
        query_path = os.path.join(base_path, 'judgedtrec2019.queries.tsv')
        judged_qids = util.Lazy(
            lambda: trec.read_qrels_dict(qrels_path).keys())
        if (force or not os.path.exists(query_path)):
            with util.finialized_file(query_path, 'wt') as f:
                for qid, qtext in plaintext.read_tsv(
                        os.path.join(base_path, 'trec2019.queries.tsv')):
                    if qid in judged_qids():
                        plaintext.write_tsv(f, [(qid, qtext)])
        msrun_path = os.path.join(base_path, 'judgedtrec2019.msrun')
        if (force or not os.path.exists(msrun_path)) and self._confirm_dua():
            with util.finialized_file(msrun_path, 'wt') as f:
                for qid, dids in trec.read_run_dict(
                        os.path.join(base_path, 'trec2019.msrun')).items():
                    if qid in judged_qids():
                        trec.write_run_dict(f, {qid: dids})

        # A subset of dev that only contains queries that have relevance judgments
        judgeddev_path = os.path.join(base_path, 'judgeddev')
        judged_qids = util.Lazy(lambda: trec.read_qrels_dict(
            os.path.join(base_path, 'dev.qrels')).keys())
        if not os.path.exists(f'{judgeddev_path}.qrels'):
            os.symlink('dev.qrels', f'{judgeddev_path}.qrels')
        if not os.path.exists(f'{judgeddev_path}.queries.tsv'):
            with util.finialized_file(f'{judgeddev_path}.queries.tsv',
                                      'wt') as f:
                for qid, qtext in plaintext.read_tsv(
                        os.path.join(base_path, 'dev.queries.tsv')):
                    if qid in judged_qids():
                        plaintext.write_tsv(f, [(qid, qtext)])
        if not self.config['init_skip_msrun']:
            if not os.path.exists(f'{judgeddev_path}.msrun'):
                with util.finialized_file(f'{judgeddev_path}.msrun',
                                          'wt') as f:
                    for qid, dids in trec.read_run_dict(
                            os.path.join(base_path, 'dev.msrun')).items():
                        if qid in judged_qids():
                            trec.write_run_dict(f, {qid: dids})

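        # train10: a deterministic ~10% slice of train, keeping queries whose
        # numeric qid is divisible by 10.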
        if not self.config['init_skip_train10']:
            file = os.path.join(base_path, 'train10.queries.tsv')
            if not os.path.exists(file):
                with util.finialized_file(file, 'wt') as fout:
                    for qid, qtext in self.logger.pbar(
                            plaintext.read_tsv(
                                os.path.join(base_path, 'train.queries.tsv')),
                            desc='filtering queries for train10'):
                        if int(qid) % 10 == 0:
                            plaintext.write_tsv(fout, [(qid, qtext)])

            file = os.path.join(base_path, 'train10.qrels')
            if not os.path.exists(file):
                with util.finialized_file(file, 'wt') as fout, open(
                        os.path.join(base_path, 'train.qrels'), 'rt') as fin:
                    for line in self.logger.pbar(
                            fin, desc='filtering qrels for train10'):
                        qid = line.split()[0]
                        if int(qid) % 10 == 0:
                            fout.write(line)

            if not self.config['init_skip_msrun']:
                file = os.path.join(base_path, 'train10.msrun')
                if not os.path.exists(file):
                    with util.finialized_file(file, 'wt') as fout, open(
                            os.path.join(base_path, 'train.msrun'),
                            'rt') as fin:
                        for line in self.logger.pbar(
                                fin, desc='filtering msrun for train10'):
                            qid = line.split()[0]
                            if int(qid) % 10 == 0:
                                fout.write(line)

            file = os.path.join(base_path, 'train10.mspairs.gz')
            if not os.path.exists(file):
                with gzip.open(file, 'wt') as fout, gzip.open(
                        os.path.join(base_path, 'train.mspairs.gz'),
                        'rt') as fin:
                    for qid, did1, did2 in self.logger.pbar(
                            plaintext.read_tsv(fin),
                            desc='filtering mspairs for train10'):
                        if int(qid) % 10 == 0:
                            plaintext.write_tsv(fout, [(qid, did1, did2)])

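        # train_med: the medical subset of train, filtered to the
        # med-msmarco query ids fetched from the covid-neural-ir repository.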
        if not self.config['init_skip_train_med']:
            med_qids = util.Lazy(
                lambda: {
                    qid.strip()
                    for qid in util.download_stream(
                        'https://raw.githubusercontent.com/Georgetown-IR-Lab/covid-neural-ir/master/med-msmarco-train.txt',
                        'utf8',
                        expected_md5="dc5199de7d4a872c361f89f08b1163ef")
                })
            file = os.path.join(base_path, 'train_med.queries.tsv')
            if not os.path.exists(file):
                with util.finialized_file(file, 'wt') as fout:
                    for qid, qtext in self.logger.pbar(
                            plaintext.read_tsv(
                                os.path.join(base_path, 'train.queries.tsv')),
                            desc='filtering queries for train_med'):
                        if qid in med_qids():
                            plaintext.write_tsv(fout, [(qid, qtext)])

            file = os.path.join(base_path, 'train_med.qrels')
            if not os.path.exists(file):
                with util.finialized_file(file, 'wt') as fout, open(
                        os.path.join(base_path, 'train.qrels'), 'rt') as fin:
                    for line in self.logger.pbar(
                            fin, desc='filtering qrels for train_med'):
                        qid = line.split()[0]
                        if qid in med_qids():
                            fout.write(line)

            if not self.config['init_skip_msrun']:
                file = os.path.join(base_path, 'train_med.msrun')
                if not os.path.exists(file):
                    with util.finialized_file(file, 'wt') as fout, open(
                            os.path.join(base_path, 'train.msrun'),
                            'rt') as fin:
                        for line in self.logger.pbar(
                                fin, desc='filtering msrun for train_med'):
                            qid = line.split()[0]
                            if qid in med_qids():
                                fout.write(line)

            file = os.path.join(base_path, 'train_med.mspairs.gz')
            if not os.path.exists(file):
                with gzip.open(file, 'wt') as fout, gzip.open(
                        os.path.join(base_path, 'train.mspairs.gz'),
                        'rt') as fin:
                    for qid, did1, did2 in self.logger.pbar(
                            plaintext.read_tsv(fin),
                            desc='filtering mspairs for train_med'):
                        if qid in med_qids():
                            plaintext.write_tsv(fout, [(qid, did1, did2)])
Example #9
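    # wikIR: download the requested wikIR package (1k or 59k) once and
    # materialize the indices, queries, qrels, and provided BM25 runs for the
    # train/dev/test splits.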
    def _init_collection(self, collection, force=False):
        base_path = util.path_dataset(self)
        if collection == '1k':
            idxs = [self.index1k, self.index1k_stem, self.docstore1k]
        elif collection == '59k':
            idxs = [self.index59k, self.index59k_stem, self.docstore59k]
        else:
            raise ValueError(f'unsupported collection {collection}')

        query_files = {
            f'wikIR{collection}/training/queries.csv':
            os.path.join(base_path, f'train.{collection}.queries'),
            f'wikIR{collection}/validation/queries.csv':
            os.path.join(base_path, f'dev.{collection}.queries'),
            f'wikIR{collection}/test/queries.csv':
            os.path.join(base_path, f'test.{collection}.queries')
        }

        qrels_files = {
            f'wikIR{collection}/training/qrels':
            os.path.join(base_path, f'train.{collection}.qrels'),
            f'wikIR{collection}/validation/qrels':
            os.path.join(base_path, f'dev.{collection}.qrels'),
            f'wikIR{collection}/test/qrels':
            os.path.join(base_path, f'test.{collection}.qrels')
        }

        theirbm25_files = {
            f'wikIR{collection}/training/BM25.res':
            os.path.join(base_path, f'train.{collection}.theirbm25'),
            f'wikIR{collection}/validation/BM25.res':
            os.path.join(base_path, f'dev.{collection}.theirbm25'),
            f'wikIR{collection}/test/BM25.res':
            os.path.join(base_path, f'test.{collection}.theirbm25')
        }

        if not force and \
           all(i.built() for i in idxs) and \
           all(os.path.exists(f) for f in query_files.values()) and \
           all(os.path.exists(f) for f in qrels_files.values()) and \
           all(os.path.exists(f) for f in theirbm25_files.values()):
            return

        if not self._confirm_dua():
            return

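        # Single pass over the downloaded zip: build the indices from the
        # documents, convert queries from CSV to TSV, re-delimit the qrels
        # with spaces, and copy the provided BM25 run files verbatim.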
        with util.download_tmp(_SOURCES[collection]) as f:
            with zipfile.ZipFile(f) as zipf:
                doc_iter = self._init_iter_collection(zipf, collection)
                self._init_indices_parallel(idxs, doc_iter, force)

                for zqueryf, queryf in query_files.items():
                    if force or not os.path.exists(queryf):
                        with zipf.open(zqueryf) as f, open(queryf,
                                                           'wt') as out:
                            f = io.TextIOWrapper(f)
                            f.readline()  # skip the CSV header row
                            for qid, text in plaintext.read_sv(f, ','):
                                plaintext.write_tsv(out, [[qid, text]])

                for zqrelf, qrelf in qrels_files.items():
                    if force or not os.path.exists(qrelf):
                        with zipf.open(zqrelf) as f, open(qrelf, 'wt') as out:
                            f = io.TextIOWrapper(f)
                            plaintext.write_sv(out, plaintext.read_tsv(f), ' ')

                for zbm25, bm25 in theirbm25_files.items():
                    if force or not os.path.exists(bm25):
                        with zipf.open(zbm25) as f, open(bm25, 'wb') as out:
                            out.write(f.read())