Example #1
    def init(self, force=False):
        base_path = util.path_dataset(self)
        idxs = [self.index, self.index_stem, self.doc_store]
        self._init_indices_parallel(idxs, self._init_iter_collection(), force)

        qrels_file = os.path.join(base_path, 'qrels.robust2004.txt')
        if (force or not os.path.exists(qrels_file)) and self._confirm_dua():
            util.download(**_FILES['qrels'], file_name=qrels_file)

        for fold in FOLDS:
            fold_qrels_file = os.path.join(base_path, f'{fold}.qrels')
            if (force or not os.path.exists(fold_qrels_file)):
                all_qrels = trec.read_qrels_dict(qrels_file)
                fold_qrels = {
                    qid: dids
                    for qid, dids in all_qrels.items() if qid in FOLDS[fold]
                }
                trec.write_qrels_dict(fold_qrels_file, fold_qrels)

        query_file = os.path.join(base_path, 'topics.txt')
        if (force or not os.path.exists(query_file)) and self._confirm_dua():
            query_file_stream = util.download_stream(**_FILES['queries'],
                                                     encoding='utf8')
            with util.finialized_file(query_file, 'wt') as f:
                plaintext.write_tsv(f,
                                    trec.parse_query_format(query_file_stream))
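
Example #1 round-trips qrels through trec.read_qrels_dict and trec.write_qrels_dict. Judging from the usage here and in Example #13, the dict shape is {qid: {did: relevance}} over the standard space-separated TREC qrels format; a minimal sketch of that assumed contract (not the actual onir trec module) is:

import collections

def read_qrels_dict(path):
    # parse 'qid iteration did relevance' lines into {qid: {did: int(rel)}}
    qrels = collections.defaultdict(dict)
    with open(path, 'rt') as f:
        for line in f:
            qid, _, did, rel = line.split()
            qrels[qid][did] = int(rel)
    return dict(qrels)

def write_qrels_dict(path, qrels):
    # inverse of read_qrels_dict; '0' fills the unused iteration column
    with open(path, 'wt') as f:
        for qid, dids in qrels.items():
            for did, rel in dids.items():
                f.write(f'{qid} 0 {did} {rel}\n')
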
Example #2
 def topics_path(self, path_topics):
     # Save the topics
     if not path_topics.is_file():
         with util.finialized_file(path_topics, 'wt') as f, \
              self.topics.path.open('rt') as query_file_stream:
             data = ((item, qid, text) for item, qid, text in
                     trec.parse_query_format(query_file_stream)
                     if qid in self.qids)
             plaintext.write_tsv(f, data)
     return path_topics
Example #3
 def write_missing_metrics(self, ctxt, missing_metrics):
     epoch = ctxt['epoch']
     for metric in missing_metrics:
         os.makedirs(os.path.join(ctxt['base_path'], metric), exist_ok=True)
         path_agg = os.path.join(ctxt['base_path'], metric, 'agg.txt')
         path_epoch = os.path.join(ctxt['base_path'], metric,
                                   f'{epoch}.txt')
         with open(path_agg, 'at') as f:
             plaintext.write_tsv(
                 f, [(str(epoch), str(ctxt['metrics'][metric]))])
         plaintext.write_tsv(path_epoch,
                             ctxt['metrics_by_query'][metric].items())
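
Across Examples #1-#3, plaintext.write_tsv is handed both an open file (the topics in Example #1, the agg.txt append above) and a bare path (the per-epoch file above). A sketch of that inferred path-or-handle contract, which the real onir plaintext helpers may refine (encoding, value conversion):

import contextlib

def write_tsv(file_or_path, rows):
    with contextlib.ExitStack() as stack:
        if isinstance(file_or_path, str):
            # given a path: open (and close) the file ourselves
            f = stack.enter_context(open(file_or_path, 'wt'))
        else:
            f = file_or_path  # given a handle: the caller picks the mode
        for row in rows:
            f.write('\t'.join(row) + '\n')

def read_tsv(file_or_path):
    with contextlib.ExitStack() as stack:
        if isinstance(file_or_path, str):
            f = stack.enter_context(open(file_or_path, 'rt'))
        else:
            f = file_or_path
        for line in f:
            yield tuple(line.rstrip('\n').split('\t'))
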
Example #4
    def init(self, force=False):
        base_path = util.path_dataset(self)
        idxs = [self.index, self.index_stem, self.doc_store]
        self._init_indices_parallel(idxs, self._init_iter_collection(), force)
        train_qrels = os.path.join(base_path, 'train.qrels.txt')
        valid_qrels = os.path.join(base_path, 'valid.qrels.txt')
        test_qrels = os.path.join(base_path, 'test.qrels.txt')

        if (force or not os.path.exists(train_qrels)
                or not os.path.exists(valid_qrels)) and self._confirm_dua():
            source_stream = util.download_stream(**_FILES['qrels_2013'],
                                                 encoding='utf8')
            source_stream2 = util.download_stream(**_FILES['qrels_2014'],
                                                  encoding='utf8')
            with util.finialized_file(train_qrels, 'wt') as tf, \
                 util.finialized_file(valid_qrels, 'wt') as vf, \
                 util.finialized_file(test_qrels, 'wt') as Tf:
                for line in source_stream:
                    cols = line.strip().split()
                    if int(cols[0]) in VALIDATION_QIDS:
                        vf.write(' '.join(cols) + '\n')
                    elif int(cols[0]) in TEST_QIDS:
                        Tf.write(' '.join(cols) + '\n')
                    else:
                        tf.write(' '.join(cols) + '\n')
                for line in source_stream2:
                    cols = line.strip().split()
                    # int() to match the qid type checked elsewhere in this loop
                    if int(cols[0]) in VALIDATION_QIDS:
                        vf.write(' '.join(cols) + '\n')
                    elif int(cols[0]) in TEST_QIDS:
                        Tf.write(' '.join(cols) + '\n')
                    else:
                        tf.write(' '.join(cols) + '\n')

        all_queries = os.path.join(base_path, 'topics.txt')

        if (force or not os.path.exists(all_queries)) and self._confirm_dua():
            source_stream = util.download_stream(**_FILES['queries_2013'],
                                                 encoding='utf8')
            source_stream2 = util.download_stream(**_FILES['queries_2014'],
                                                  encoding='utf8')
            queries = []
            for _id, _query in trec.parse_query_mbformat(source_stream):
                nid = _id.replace('MB', '').strip()
                queries.append([nid, _query])

            for _id, _query in trec.parse_query_mbformat(source_stream2):
                nid = _id.replace('MB', '').strip()
                queries.append([nid, _query])

            plaintext.write_tsv(all_queries, queries)
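
The three-way split above (and the two-way one in Example #5) routes each qrels line to a file by qid membership. The same routing, table-driven for any number of splits (a sketch; the qid sets and open files are supplied by the caller):

def split_qrels(lines, routes, default_file):
    # routes: list of (qid_set, open_file) pairs; each line goes to the
    # first file whose set contains its qid, otherwise to default_file
    for line in lines:
        cols = line.strip().split()
        for qids, f in routes:
            if cols[0] in qids:
                f.write(' '.join(cols) + '\n')
                break
        else:
            default_file.write(' '.join(cols) + '\n')

With routes=[(VALIDATION_QIDS, vf), (TEST_QIDS, Tf)] and default_file=tf, this reproduces the loop bodies above.
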
Example #5
    def init(self, force=False):
        idxs = [self.index, self.index_stem, self.doc_store]
        self._init_indices_parallel(idxs, self._init_iter_collection(), force)

        train_qrels = os.path.join(util.path_dataset(self), 'train.qrels.txt')
        valid_qrels = os.path.join(util.path_dataset(self), 'valid.qrels.txt')
        if (force or not os.path.exists(train_qrels)
                or not os.path.exists(valid_qrels)) and self._confirm_dua():
            source_stream = util.download_stream(
                'https://ciir.cs.umass.edu/downloads/Antique/antique-train.qrel',
                encoding='utf8')
            with util.finialized_file(train_qrels, 'wt') as tf, \
                 util.finialized_file(valid_qrels, 'wt') as vf:
                for line in source_stream:
                    cols = line.strip().split()
                    if cols[0] in VALIDATION_QIDS:
                        vf.write(' '.join(cols) + '\n')
                    else:
                        tf.write(' '.join(cols) + '\n')

        train_queries = os.path.join(util.path_dataset(self),
                                     'train.queries.txt')
        valid_queries = os.path.join(util.path_dataset(self),
                                     'valid.queries.txt')
        if (force or not os.path.exists(train_queries)
                or not os.path.exists(valid_queries)) and self._confirm_dua():
            source_stream = util.download_stream(
                'https://ciir.cs.umass.edu/downloads/Antique/antique-train-queries.txt',
                encoding='utf8')
            train, valid = [], []
            for cols in plaintext.read_tsv(source_stream):
                if cols[0] in VALIDATION_QIDS:
                    valid.append(cols)
                else:
                    train.append(cols)
            plaintext.write_tsv(train_queries, train)
            plaintext.write_tsv(valid_queries, valid)

        test_qrels = os.path.join(util.path_dataset(self), 'test.qrels.txt')
        if (force or not os.path.exists(test_qrels)) and self._confirm_dua():
            util.download(
                'https://ciir.cs.umass.edu/downloads/Antique/antique-test.qrel',
                test_qrels)

        test_queries = os.path.join(util.path_dataset(self),
                                    'test.queries.txt')
        if (force or not os.path.exists(test_queries)) and self._confirm_dua():
            util.download(
                'https://ciir.cs.umass.edu/downloads/Antique/antique-test-queries.txt',
                test_queries)
Example #6
    def init(self, force=False):
        base_dir = os.path.join(util.path_dataset(self), self.subset)

        if self.subset == 'dummy':
            datafile = os.path.join(base_dir, 'datafile.tsv')
            qrels = os.path.join(base_dir, 'qrels.txt')
            if not os.path.exists(datafile):
                os.symlink(os.path.abspath('etc/dummy_datafile.tsv'), datafile)
            if not os.path.exists(qrels):
                os.symlink(os.path.abspath('etc/dummy_qrels.txt'), qrels)

        needs_datafile = []
        if force or not self.index.built():
            needs_datafile.append(lambda it: self.index.build(
                indices.RawDoc(did, txt) for t, did, txt in it if t == 'doc'))

        if force or not self.index_stem.built():
            needs_datafile.append(lambda it: self.index_stem.build(
                indices.RawDoc(did, txt) for t, did, txt in it if t == 'doc'))

        if force or not self.doc_store.built():
            needs_datafile.append(lambda it: self.doc_store.build(
                indices.RawDoc(did, txt) for t, did, txt in it if t == 'doc'))

        query_file = os.path.join(base_dir, 'queries.tsv')
        if force or not os.path.exists(query_file):
            needs_datafile.append(
                lambda it: plaintext.write_tsv(query_file, (
                    (qid, txt) for t, qid, txt in it if t == 'query')))

        if needs_datafile:
            df_glob = os.path.join(base_dir, 'datafile*.tsv')
            datafiles = glob(df_glob)
            while not datafiles:
                c = util.confirm(
                    f'No data files found. Please move/link data files to {df_glob}.\n'
                    'Data files should contain both queries and documents in the '
                    'following format (one per line):\n'
                    '[query|doc] [TAB] [qid/did] [TAB] [text]')
                if not c:
                    sys.exit(1)
                datafiles = glob(df_glob)
            main_iter = itertools.chain(*(plaintext.read_tsv(df)
                                          for df in datafiles))
            main_iter = tqdm(main_iter, desc='reading datafiles')
            iters = util.blocking_tee(main_iter, len(needs_datafile))
            with contextlib.ExitStack() as stack:
                for fn, it in zip(needs_datafile, iters):
                    stack.enter_context(
                        util.CtxtThread(functools.partial(fn, it)))

        qrels_file = os.path.join(base_dir, 'qrels.txt')
        while not os.path.exists(qrels_file):
            c = util.confirm(
                f'No qrels file found. Please move/link qrels file to {qrels_file}.\n'
                'Qrels file should be in the TREC format:\n'
                '[qid] [SPACE] Q0 [SPACE] [did] [SPACE] [score]')
            if not c:
                sys.exit(1)
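
Example #6 pushes one datafile iterator through several consumers at once via util.blocking_tee and util.CtxtThread. Those helpers are onir's own; a rough standard-library stand-in that mimics the behaviour observed here (one producer thread feeding bounded queues, one consumer thread per builder) might look like:

import queue
import threading

_END = object()  # sentinel marking the end of the stream

def blocking_tee(source, n):
    # fan one iterator out to n iterators; the producer blocks until the
    # slowest consumer catches up, because the queues are bounded
    queues = [queue.Queue(maxsize=16) for _ in range(n)]

    def pump():
        for item in source:
            for q in queues:
                q.put(item)
        for q in queues:
            q.put(_END)

    threading.Thread(target=pump, daemon=True).start()

    def drain(q):
        while True:
            item = q.get()
            if item is _END:
                return
            yield item

    return [drain(q) for q in queues]

class CtxtThread(threading.Thread):
    # context manager that starts a function in a thread on __enter__ and
    # joins it on __exit__, so an ExitStack waits for all consumers
    def __init__(self, fn):
        super().__init__(target=fn)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc):
        self.join()
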
Example #7
 def _init_topics(self,
                  subset,
                  topic_files,
                  qid_prefix=None,
                  encoding=None,
                  xml_prefix=None,
                  force=False,
                  expected_md5=None):
     topicf = os.path.join(util.path_dataset(self), f'{subset}.topics')
     if (force or not os.path.exists(topicf)) and self._confirm_dua():
         topics = []
         for topic_file in topic_files:
             topic_file_stream = util.download_stream(
                 topic_file, encoding, expected_md5=expected_md5)
             for t, qid, text in trec.parse_query_format(
                     topic_file_stream, xml_prefix):
                 if qid_prefix is not None:
                     qid = qid.replace(qid_prefix, '')
                 topics.append((t, qid, text))
         plaintext.write_tsv(topicf, topics)
Example #8
    def init(self, force=False):
        needs_docs = []
        for index in [self.index_stem, self.index_stem_2020, self.doc_store]:
            if force or not index.built():
                needs_docs.append(index)

        if needs_docs and self._confirm_dua():
            with contextlib.ExitStack() as stack:
                doc_iter = self._init_iter_collection()
                doc_iter = self.logger.pbar(doc_iter, desc='articles')
                doc_iters = util.blocking_tee(doc_iter, len(needs_docs))
                for idx, it in zip(needs_docs, doc_iters):
                    if idx is self.index_stem_2020:
                        it = (d for d in it if '2020' in d.data['date'])
                    stack.enter_context(
                        util.CtxtThread(functools.partial(idx.build, it)))

        path = os.path.join(util.path_dataset(self), 'rnd1.tsv')
        if not os.path.exists(path) and self._confirm_dua():
            with util.download_tmp('https://ir.nist.gov/covidSubmit/data/topics-rnd1.xml', expected_md5="cf1b605222f45f7dbc90ca8e4d9b2c31") as f, \
                 util.finialized_file(path, 'wt') as fout:
                soup = BeautifulSoup(f.read(), 'lxml-xml')
                for topic in soup.find_all('topic'):
                    qid = topic['number']
                    plaintext.write_tsv(fout, [
                        (qid, 'query', topic.find('query').get_text()),
                        (qid, 'quest', topic.find('question').get_text()),
                        (qid, 'narr', topic.find('narrative').get_text()),
                    ])

        udel_flag = path + '.includes_udel'
        if not os.path.exists(udel_flag):
            with open(path, 'at') as fout, \
                 util.finialized_file(udel_flag, 'wt'):
                with util.download_tmp(
                        'https://raw.githubusercontent.com/castorini/anserini/master/src/main/resources/topics-and-qrels/topics.covid-round1-udel.xml',
                        expected_md5="2915cf59ae222f0aa20b2a671f67fd7a") as f:
                    soup = BeautifulSoup(f.read(), 'lxml-xml')
                    for topic in soup.find_all('topic'):
                        qid = topic['number']
                        plaintext.write_tsv(fout, [
                            (qid, 'udel', topic.find('query').get_text()),
                        ])

        path = os.path.join(util.path_dataset(self), 'rnd2.tsv')
        if not os.path.exists(path) and self._confirm_dua():
            with util.download_tmp('https://ir.nist.gov/covidSubmit/data/topics-rnd2.xml', expected_md5="550129e71c83de3fb4d6d29a172c5842") as f, \
                 util.finialized_file(path, 'wt') as fout:
                soup = BeautifulSoup(f.read(), 'lxml-xml')
                for topic in soup.find_all('topic'):
                    qid = topic['number']
                    plaintext.write_tsv(fout, [
                        (qid, 'query', topic.find('query').get_text()),
                        (qid, 'quest', topic.find('question').get_text()),
                        (qid, 'narr', topic.find('narrative').get_text()),
                    ])

        udel_flag = path + '.includes_udel'
        if not os.path.exists(udel_flag):
            with open(path, 'at') as fout, \
                 util.finialized_file(udel_flag, 'wt'):
                with util.download_tmp(
                        'https://raw.githubusercontent.com/castorini/anserini/master/src/main/resources/topics-and-qrels/topics.covid-round2-udel.xml',
                        expected_md5="a8988734e6f812921d5125249c197985") as f:
                    soup = BeautifulSoup(f.read(), 'lxml-xml')
                    for topic in soup.find_all('topic'):
                        qid = topic['number']
                        plaintext.write_tsv(fout, [
                            (qid, 'udel', topic.find('query').get_text()),
                        ])

        path = os.path.join(util.path_dataset(self), 'rnd1.qrels')
        if not os.path.exists(path) and self._confirm_dua():
            util.download(
                'https://ir.nist.gov/covidSubmit/data/qrels-rnd1.txt',
                path,
                expected_md5="d58586df5823e7d1d0b3619a73b31518")
Example #9
 def __setitem__(self, epoch, value):
     self.content[epoch] = float(value)
     with open(self.path, 'at') as f:
         plaintext.write_tsv(f, [(str(epoch), str(value))])
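
Because the file is opened in 'at' mode, every assignment appends one (epoch, value) row, making the file an append-only log. A matching loader (hypothetical; the class's constructor is not shown here) would replay it with the last write winning:

def load_content(path):
    # rebuild the epoch -> value dict from the append-only TSV log
    content = {}
    try:
        with open(path, 'rt') as f:
            for line in f:
                epoch, value = line.rstrip('\n').split('\t')
                content[int(epoch)] = float(value)  # last write wins
    except FileNotFoundError:
        pass  # no log yet: start empty
    return content
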
Example #10
 def wrapped(it):
     with util.finialized_file(file, 'wt') as f:
         for doc in it:
             # the equality acts as a selector: keep held-out docs when
             # is_heldout is True, all remaining docs when it is False
             if is_heldout == (doc.did in _HELD_OUT_IDS):
                 plaintext.write_tsv(f, [(doc.did, doc.data['headline'])])
Example #11
    def __call__(self, ctxt):
        cached = True
        epoch = ctxt['epoch']
        base_path = os.path.join(ctxt['base_path'],
                                 self.pred.dataset.path_segment())
        if (self.pred.config['source'] == 'run'
                and self.pred.config['run_threshold'] > 0):
            base_path = '{p}_runthreshold-{run_threshold}'.format(
                p=base_path, **self.pred.config)
        os.makedirs(os.path.join(base_path, 'runs'), exist_ok=True)
        with open(os.path.join(base_path, 'config.json'), 'wt') as f:
            json.dump(self.pred.dataset.config, f)
        run_path = os.path.join(base_path, 'runs', f'{epoch}.run')
        if os.path.exists(run_path):
            run = trec.read_run_dict(run_path)
        else:
            if (self.pred.config['source'] == 'run'
                    and self.pred.config['run_threshold'] > 0):
                official_run = self.pred.dataset.run('dict')
            else:
                official_run = {}
            run = {}
            ranker = ctxt['ranker']().to(self.device)
            this_qid = None
            these_docs = {}
            with util.finialized_file(run_path, 'wt') as f:
                for qid, did, score in self.pred.iter_scores(
                        ranker, self.datasource, self.device):
                    if qid != this_qid:
                        if this_qid is not None:
                            these_docs = self._apply_threshold(
                                these_docs, official_run.get(this_qid, {}))
                            trec.write_run_dict(f, {this_qid: these_docs})
                        this_qid = qid
                        these_docs = {}
                    these_docs[did] = score
                if this_qid is not None:
                    these_docs = self._apply_threshold(
                        these_docs, official_run.get(this_qid, {}))
                    trec.write_run_dict(f, {this_qid: these_docs})
            cached = False

        result = {
            'epoch': epoch,
            'run': run,
            'run_path': run_path,
            'base_path': base_path,
            'cached': cached
        }

        result['metrics'] = {
            m: None
            for m in self.pred.config['measures'].split(',') if m
        }
        result['metrics_by_query'] = {m: None for m in result['metrics']}

        missing_metrics = self.load_metrics(result)

        if missing_metrics:
            measures = set(missing_metrics)
            result['cached'] = False
            qrels = self.pred.dataset.qrels()
            calculated_metrics = onir.metrics.calc(qrels, run_path, measures)
            result['metrics_by_query'].update(calculated_metrics)
            result['metrics'].update(onir.metrics.mean(calculated_metrics))
            self.write_missing_metrics(result, missing_metrics)

        try:
            if ctxt['ranker']().config.get('add_runscore'):
                result['metrics']['runscore_alpha'] = torch.sigmoid(
                    ctxt['ranker']().runscore_alpha).item()
                rs_alpha_f = os.path.join(ctxt['base_path'],
                                          'runscore_alpha.txt')
                with open(rs_alpha_f, 'at') as f:
                    # write via the append handle; passing the path here
                    # would reopen the file and clobber earlier epochs
                    plaintext.write_tsv(f, [
                        (str(epoch), str(result['metrics']['runscore_alpha']))
                    ])
        except FileNotFoundError:
            pass  # model may no longer exist, ignore

        return result
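
The this_qid/these_docs bookkeeping in the run-writing loop above is the usual "flush a group when the key changes" pattern, and it relies on all of a qid's scores arriving consecutively. With itertools.groupby the same logic condenses to the sketch below, where apply_threshold and write_run_dict stand in for the example's self._apply_threshold and trec.write_run_dict:

import itertools

def write_grouped_run(f, scores, official_run, apply_threshold,
                      write_run_dict):
    # scores: iterable of (qid, did, score) tuples with equal qids adjacent
    for qid, group in itertools.groupby(scores, key=lambda t: t[0]):
        docs = {did: score for _, did, score in group}
        docs = apply_threshold(docs, official_run.get(qid, {}))
        write_run_dict(f, {qid: docs})
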
Example #12
 def _init_queryfile(self, in_stream, out_path, force=False):
     if force or not os.path.exists(out_path):
         with util.finialized_file(out_path, 'wt') as out:
             with self.logger.duration(f'extracting to {out_path}'):
                 for qid, headings in car.iter_queries(in_stream):
                     plaintext.write_tsv(out, [(qid, *headings)])
Example #13
    def init(self, force=False):
        idxs = [self.index_stem, self.doc_store]
        self._init_indices_parallel(idxs, self._init_iter_collection(), force)

        base_path = util.path_dataset(self)

        needs_queries = []
        if force or not os.path.exists(
                os.path.join(base_path, 'train.queries.tsv')):
            needs_queries.append(lambda it: plaintext.write_tsv(
                os.path.join(base_path, 'train.queries.tsv'),
                ((qid, txt) for file, qid, txt in it
                 if file == 'queries.train.tsv' and qid not in MINI_DEV)))
        if force or not os.path.exists(
                os.path.join(base_path, 'minidev.queries.tsv')):
            needs_queries.append(lambda it: plaintext.write_tsv(
                os.path.join(base_path, 'minidev.queries.tsv'),
                ((qid, txt) for file, qid, txt in it
                 if file == 'queries.train.tsv' and qid in MINI_DEV)))
        if force or not os.path.exists(
                os.path.join(base_path, 'dev.queries.tsv')):
            needs_queries.append(lambda it: plaintext.write_tsv(
                os.path.join(base_path, 'dev.queries.tsv'),
                ((qid, txt) for file, qid, txt in it
                 if file == 'queries.dev.tsv')))
        if force or not os.path.exists(
                os.path.join(base_path, 'eval.queries.tsv')):
            needs_queries.append(lambda it: plaintext.write_tsv(
                os.path.join(base_path, 'eval.queries.tsv'),
                ((qid, txt) for file, qid, txt in it
                 if file == 'queries.eval.tsv')))

        if needs_queries and self._confirm_dua():
            with util.download_tmp(_SOURCES['queries']) as f, \
                 tarfile.open(fileobj=f) as tarf, \
                 contextlib.ExitStack() as ctxt:

                def _extr_subf(subf):
                    for qid, txt in plaintext.read_tsv(
                            io.TextIOWrapper(tarf.extractfile(subf))):
                        yield subf, qid, txt

                query_iter = [
                    _extr_subf('queries.train.tsv'),
                    _extr_subf('queries.dev.tsv'),
                    _extr_subf('queries.eval.tsv')
                ]
                query_iter = tqdm(itertools.chain(*query_iter), desc='queries')
                query_iters = util.blocking_tee(query_iter, len(needs_queries))
                for fn, it in zip(needs_queries, query_iters):
                    ctxt.enter_context(
                        util.CtxtThread(functools.partial(fn, it)))

        file = os.path.join(base_path, 'train.qrels')
        if (force or not os.path.exists(file)) and self._confirm_dua():
            stream = util.download_stream(_SOURCES['train-qrels'], 'utf8')
            with util.finialized_file(file, 'wt') as out:
                for qid, _, did, score in plaintext.read_tsv(stream):
                    if qid not in MINI_DEV:
                        trec.write_qrels(out, [(qid, did, score)])

        file = os.path.join(base_path, 'minidev.qrels')
        if (force or not os.path.exists(file)) and self._confirm_dua():
            stream = util.download_stream(_SOURCES['train-qrels'], 'utf8')
            with util.finialized_file(file, 'wt') as out:
                for qid, _, did, score in plaintext.read_tsv(stream):
                    if qid in MINI_DEV:
                        trec.write_qrels(out, [(qid, did, score)])

        file = os.path.join(base_path, 'dev.qrels')
        if (force or not os.path.exists(file)) and self._confirm_dua():
            stream = util.download_stream(_SOURCES['dev-qrels'], 'utf8')
            with util.finialized_file(file, 'wt') as out:
                for qid, _, did, score in plaintext.read_tsv(stream):
                    trec.write_qrels(out, [(qid, did, score)])

        file = os.path.join(base_path, 'train.mspairs.gz')
        if not os.path.exists(file) and os.path.exists(
                os.path.join(base_path, 'qidpidtriples.train.full')):
            # legacy
            os.rename(os.path.join(base_path, 'qidpidtriples.train.full'),
                      file)
        if (force or not os.path.exists(file)) and self._confirm_dua():
            util.download(_SOURCES['qidpidtriples.train.full'], file)

        if not self.config['init_skip_msrun']:
            for file_name, subf in [('dev.msrun', 'top1000.dev'),
                                    ('eval.msrun', 'top1000.eval'),
                                    ('train.msrun', 'top1000.train.txt')]:
                file = os.path.join(base_path, file_name)
                if (force or not os.path.exists(file)) and self._confirm_dua():
                    run = {}
                    with util.download_tmp(_SOURCES[file_name]) as f, \
                         tarfile.open(fileobj=f) as tarf:
                        for qid, did, _, _ in tqdm(
                                plaintext.read_tsv(
                                    io.TextIOWrapper(tarf.extractfile(subf)))):
                            if qid not in run:
                                run[qid] = {}
                            run[qid][did] = 0.
                    if file_name == 'train.msrun':
                        minidev = {
                            qid: dids
                            for qid, dids in run.items() if qid in MINI_DEV
                        }
                        with self.logger.duration('writing minidev.msrun'):
                            trec.write_run_dict(
                                os.path.join(base_path, 'minidev.msrun'),
                                minidev)
                        run = {
                            qid: dids
                            for qid, dids in run.items() if qid not in MINI_DEV
                        }
                    with self.logger.duration(f'writing {file_name}'):
                        trec.write_run_dict(file, run)

        query_path = os.path.join(base_path, 'trec2019.queries.tsv')
        if (force or not os.path.exists(query_path)) and self._confirm_dua():
            stream = util.download_stream(_SOURCES['trec2019.queries'], 'utf8')
            plaintext.write_tsv(query_path, plaintext.read_tsv(stream))
        msrun_path = os.path.join(base_path, 'trec2019.msrun')
        if (force or not os.path.exists(msrun_path)) and self._confirm_dua():
            run = {}
            with util.download_stream(_SOURCES['trec2019.msrun'],
                                      'utf8') as stream:
                for qid, did, _, _ in plaintext.read_tsv(stream):
                    if qid not in run:
                        run[qid] = {}
                    run[qid][did] = 0.
            with util.finialized_file(msrun_path, 'wt') as f:
                trec.write_run_dict(f, run)

        qrels_path = os.path.join(base_path, 'trec2019.qrels')
        if not os.path.exists(qrels_path) and self._confirm_dua():
            util.download(_SOURCES['trec2019.qrels'], qrels_path)
        qrels_path = os.path.join(base_path, 'judgedtrec2019.qrels')
        if not os.path.exists(qrels_path):
            os.symlink('trec2019.qrels', qrels_path)
        query_path = os.path.join(base_path, 'judgedtrec2019.queries.tsv')
        judged_qids = util.Lazy(
            lambda: trec.read_qrels_dict(qrels_path).keys())
        if (force or not os.path.exists(query_path)):
            with util.finialized_file(query_path, 'wt') as f:
                for qid, qtext in plaintext.read_tsv(
                        os.path.join(base_path, 'trec2019.queries.tsv')):
                    if qid in judged_qids():
                        plaintext.write_tsv(f, [(qid, qtext)])
        msrun_path = os.path.join(base_path, 'judgedtrec2019.msrun')
        if (force or not os.path.exists(msrun_path)) and self._confirm_dua():
            with util.finialized_file(msrun_path, 'wt') as f:
                for qid, dids in trec.read_run_dict(
                        os.path.join(base_path, 'trec2019.msrun')).items():
                    if qid in judged_qids():
                        trec.write_run_dict(f, {qid: dids})

        # A subset of dev that only contains queries that have relevance judgments
        judgeddev_path = os.path.join(base_path, 'judgeddev')
        judged_qids = util.Lazy(lambda: trec.read_qrels_dict(
            os.path.join(base_path, 'dev.qrels')).keys())
        if not os.path.exists(f'{judgeddev_path}.qrels'):
            os.symlink('dev.qrels', f'{judgeddev_path}.qrels')
        if not os.path.exists(f'{judgeddev_path}.queries.tsv'):
            with util.finialized_file(f'{judgeddev_path}.queries.tsv',
                                      'wt') as f:
                for qid, qtext in plaintext.read_tsv(
                        os.path.join(base_path, 'dev.queries.tsv')):
                    if qid in judged_qids():
                        plaintext.write_tsv(f, [(qid, qtext)])
        # dev.msrun is only produced when msrun initialization isn't skipped,
        # so gate this derived file on the same condition
        if not self.config['init_skip_msrun']:
            if not os.path.exists(f'{judgeddev_path}.msrun'):
                with util.finialized_file(f'{judgeddev_path}.msrun',
                                          'wt') as f:
                    for qid, dids in trec.read_run_dict(
                            os.path.join(base_path, 'dev.msrun')).items():
                        if qid in judged_qids():
                            trec.write_run_dict(f, {qid: dids})

        if not self.config['init_skip_train10']:
            file = os.path.join(base_path, 'train10.queries.tsv')
            if not os.path.exists(file):
                with util.finialized_file(file, 'wt') as fout:
                    for qid, qtext in self.logger.pbar(
                            plaintext.read_tsv(
                                os.path.join(base_path, 'train.queries.tsv')),
                            desc='filtering queries for train10'):
                        if int(qid) % 10 == 0:
                            plaintext.write_tsv(fout, [(qid, qtext)])

            file = os.path.join(base_path, 'train10.qrels')
            if not os.path.exists(file):
                with util.finialized_file(file, 'wt') as fout, open(
                        os.path.join(base_path, 'train.qrels'), 'rt') as fin:
                    for line in self.logger.pbar(
                            fin, desc='filtering qrels for train10'):
                        qid = line.split()[0]
                        if int(qid) % 10 == 0:
                            fout.write(line)

            if not self.config['init_skip_msrun']:
                file = os.path.join(base_path, 'train10.msrun')
                if not os.path.exists(file):
                    with util.finialized_file(file, 'wt') as fout, open(
                            os.path.join(base_path, 'train.msrun'),
                            'rt') as fin:
                        for line in self.logger.pbar(
                                fin, desc='filtering msrun for train10'):
                            qid = line.split()[0]
                            if int(qid) % 10 == 0:
                                fout.write(line)

            file = os.path.join(base_path, 'train10.mspairs.gz')
            if not os.path.exists(file):
                with gzip.open(file, 'wt') as fout, gzip.open(
                        os.path.join(base_path, 'train.mspairs.gz'),
                        'rt') as fin:
                    for qid, did1, did2 in self.logger.pbar(
                            plaintext.read_tsv(fin),
                            desc='filtering mspairs for train10'):
                        if int(qid) % 10 == 0:
                            plaintext.write_tsv(fout, [(qid, did1, did2)])

        if not self.config['init_skip_train_med']:
            med_qids = util.Lazy(
                lambda: {
                    qid.strip()
                    for qid in util.download_stream(
                        'https://raw.githubusercontent.com/Georgetown-IR-Lab/covid-neural-ir/master/med-msmarco-train.txt',
                        'utf8',
                        expected_md5="dc5199de7d4a872c361f89f08b1163ef")
                })
            file = os.path.join(base_path, 'train_med.queries.tsv')
            if not os.path.exists(file):
                with util.finialized_file(file, 'wt') as fout:
                    for qid, qtext in self.logger.pbar(
                            plaintext.read_tsv(
                                os.path.join(base_path, 'train.queries.tsv')),
                            desc='filtering queries for train_med'):
                        if qid in med_qids():
                            plaintext.write_tsv(fout, [(qid, qtext)])

            file = os.path.join(base_path, 'train_med.qrels')
            if not os.path.exists(file):
                with util.finialized_file(file, 'wt') as fout, open(
                        os.path.join(base_path, 'train.qrels'), 'rt') as fin:
                    for line in self.logger.pbar(
                            fin, desc='filtering qrels for train_med'):
                        qid = line.split()[0]
                        if qid in med_qids():
                            fout.write(line)

            if not self.config['init_skip_msrun']:
                file = os.path.join(base_path, 'train_med.msrun')
                if not os.path.exists(file):
                    with util.finialized_file(file, 'wt') as fout, open(
                            os.path.join(base_path, 'train.msrun'),
                            'rt') as fin:
                        for line in self.logger.pbar(
                                fin, desc='filtering msrun for train_med'):
                            qid = line.split()[0]
                            if qid in med_qids():
                                fout.write(line)

            file = os.path.join(base_path, 'train_med.mspairs.gz')
            if not os.path.exists(file):
                with gzip.open(file, 'wt') as fout, gzip.open(
                        os.path.join(base_path, 'train.mspairs.gz'),
                        'rt') as fin:
                    for qid, did1, did2 in self.logger.pbar(
                            plaintext.read_tsv(fin),
                            desc='filtering mspairs for train_med'):
                        if qid in med_qids():
                            plaintext.write_tsv(fout, [(qid, did1, did2)])
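
Example #13 wraps its expensive lookups (qrels keys, the med-qids download) in util.Lazy so they are materialized only if some branch actually calls them, and then only once. A minimal stand-in for such a call-once cache (the real onir.util.Lazy may differ):

class Lazy:
    _UNSET = object()

    def __init__(self, fn):
        self._fn = fn
        self._value = Lazy._UNSET

    def __call__(self):
        # compute on first call, return the cached value afterwards
        if self._value is Lazy._UNSET:
            self._value = self._fn()
        return self._value
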
Example #14
    def _init_collection(self, collection, force=False):
        base_path = util.path_dataset(self)
        if collection == '1k':
            idxs = [self.index1k, self.index1k_stem, self.docstore1k]
        elif collection == '59k':
            idxs = [self.index59k, self.index59k_stem, self.docstore59k]
        else:
            raise ValueError(f'unsupported collection {collection}')

        query_files = {
            f'wikIR{collection}/training/queries.csv':
            os.path.join(base_path, f'train.{collection}.queries'),
            f'wikIR{collection}/validation/queries.csv':
            os.path.join(base_path, f'dev.{collection}.queries'),
            f'wikIR{collection}/test/queries.csv':
            os.path.join(base_path, f'test.{collection}.queries')
        }

        qrels_files = {
            f'wikIR{collection}/training/qrels':
            os.path.join(base_path, f'train.{collection}.qrels'),
            f'wikIR{collection}/validation/qrels':
            os.path.join(base_path, f'dev.{collection}.qrels'),
            f'wikIR{collection}/test/qrels':
            os.path.join(base_path, f'test.{collection}.qrels')
        }

        theirbm25_files = {
            f'wikIR{collection}/training/BM25.res':
            os.path.join(base_path, f'train.{collection}.theirbm25'),
            f'wikIR{collection}/validation/BM25.res':
            os.path.join(base_path, f'dev.{collection}.theirbm25'),
            f'wikIR{collection}/test/BM25.res':
            os.path.join(base_path, f'test.{collection}.theirbm25')
        }

        if not force and \
           all(i.built() for i in idxs) and \
           all(os.path.exists(f) for f in query_files.values()) and \
           all(os.path.exists(f) for f in qrels_files.values()) and \
           all(os.path.exists(f) for f in theirbm25_files.values()):
            return

        if not self._confirm_dua():
            return

        with util.download_tmp(_SOURCES[collection]) as f:
            with zipfile.ZipFile(f) as zipf:
                doc_iter = self._init_iter_collection(zipf, collection)
                self._init_indices_parallel(idxs, doc_iter, force)

                for zqueryf, queryf in query_files.items():
                    if force or not os.path.exists(queryf):
                        with zipf.open(zqueryf) as f, open(queryf,
                                                           'wt') as out:
                            f = io.TextIOWrapper(f)
                            f.readline()  # head
                            for qid, text in plaintext.read_sv(f, ','):
                                plaintext.write_tsv(out, [[qid, text]])

                for zqrelf, qrelf in qrels_files.items():
                    if force or not os.path.exists(qrelf):
                        with zipf.open(zqrelf) as f, open(qrelf, 'wt') as out:
                            f = io.TextIOWrapper(f)
                            plaintext.write_sv(out, plaintext.read_tsv(f), ' ')

                for zbm25, bm25 in theirbm25_files.items():
                    if force or not os.path.exists(bm25):
                        with zipf.open(zbm25) as f, open(bm25, 'wb') as out:
                            out.write(f.read())
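
The query files inside the wikIR zip are comma-separated and get rewritten as TSV via plaintext.read_sv. A standalone version of that conversion using the standard library's csv module (which also copes with quoted commas, something read_sv may or may not do; the argument names are hypothetical):

import csv
import io
import zipfile

def csv_member_to_tsv(zip_path, member, out_path):
    # extract one CSV member of a zip and rewrite it as qid<TAB>text
    with zipfile.ZipFile(zip_path) as zipf, \
         zipf.open(member) as raw, \
         open(out_path, 'wt') as out:
        reader = csv.reader(io.TextIOWrapper(raw, encoding='utf8'))
        next(reader)  # skip the header row, as the example does
        for qid, text in reader:
            out.write(f'{qid}\t{text}\n')
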