Example #1
 def _load_topics(self):
     result = {}
     for item, qid, text in plaintext.read_tsv(
             str(self.assessed_topics.topics_path())):
         if item == 'topic':
             result[qid] = text
     return result
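
(Note: every example on this page is built around plaintext.read_tsv from OpenNIR. Judging only from its call sites, which pass sometimes a file path and sometimes an already-open text stream, and always unpack one tuple of column values per line, a minimal stand-in could look like the sketch below. This is an inference from usage, not the project's actual implementation.)

    def read_tsv(file):
        # The examples pass either a filename (Example #1) or an open text
        # stream such as a TextIOWrapper (Example #7), so accept both. Assumption.
        handle = open(file, 'rt') if isinstance(file, str) else file
        with handle:
            for line in handle:
                # one tuple of tab-separated column values per input line
                yield tuple(line.rstrip('\n').split('\t'))
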
Example #2
 def pair_iter(self, fields, pos_source='intersect', neg_source='run', sampling='query', pos_minrel=1, unjudged_rel=0, num_neg=1, random=None, inf=False):
     special = self.config['special']
     if special == '':
         raise NotImplementedError
     assert pos_minrel == 1, f"{special} only supports pos_minrel=1"
     assert unjudged_rel == 0, f"{special} only supports unjudged_rel=0"
     assert num_neg == 1, f"{special} only supports num_neg=1"
     assert self.config['subset'] in ('train', 'train10'), f"{special} only supported with subset=train[10]"
     self.logger.warn(f'Using {special}; ignoring pair_iter arguments pos_source={pos_source} neg_source={neg_source} sampling={sampling}')
     first = True
     while first or inf:
         first = False
         if special == 'mspairs':
             f = gzip.open(os.path.join(util.path_dataset(self), '{subset}.mspairs.gz'.format(**self.config)), 'rt')
         else:
             raise ValueError(f'unsupported special={special}')
         with f:
             for qid, pos_did, neg_did in plaintext.read_tsv(f):
                 if qid in MINI_DEV:
                     continue
                 result = {f: [] for f in fields}
                 for did in [pos_did, neg_did]:
                     record = self.build_record(fields, query_id=qid, doc_id=did)
                     for f in fields:
                         result[f].append(record[f])
                 yield result
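
The {subset}.mspairs.gz file consumed here is, per the unpacking in the loop above, a gzipped TSV with one (query id, positive doc id, negative doc id) triple per line. A hypothetical fragment, with made-up ids purely for illustration:

    q101    d_rel_1    d_nonrel_7
    q101    d_rel_1    d_nonrel_9
    q102    d_rel_3    d_nonrel_2
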
Example #3
File: robust.py Project: tgeral68/OpenNIR
 def _load_topics(self):
     result = {}
     for item, qid, text in plaintext.read_tsv(
             os.path.join(util.path_dataset(self), 'topics.txt')):
         if item == 'topic':
             result[qid] = text
     return result
Example #4
    def init(self, force=False):
        base_dir = os.path.join(util.path_dataset(self), self.subset)

        if self.subset == 'dummy':
            datafile = os.path.join(base_dir, 'datafile.tsv')
            qrels = os.path.join(base_dir, 'qrels.txt')
            if not os.path.exists(datafile):
                os.symlink(os.path.abspath('etc/dummy_datafile.tsv'), datafile)
            if not os.path.exists(qrels):
                os.symlink(os.path.abspath('etc/dummy_qrels.txt'), qrels)

        needs_datafile = []
        if force or not self.index.built():
            needs_datafile.append(lambda it: self.index.build(
                indices.RawDoc(did, txt) for t, did, txt in it if t == 'doc'))

        if force or not self.index_stem.built():
            needs_datafile.append(lambda it: self.index_stem.build(
                indices.RawDoc(did, txt) for t, did, txt in it if t == 'doc'))

        if force or not self.doc_store.built():
            needs_datafile.append(lambda it: self.doc_store.build(
                indices.RawDoc(did, txt) for t, did, txt in it if t == 'doc'))

        query_file = os.path.join(base_dir, 'queries.tsv')
        if force or not os.path.exists(query_file):
            needs_datafile.append(
                lambda it: plaintext.write_tsv(query_file, (
                    (qid, txt) for t, qid, txt in it if t == 'query')))

        if needs_datafile:
            df_glob = os.path.join(base_dir, 'datafile*.tsv')
            datafiles = glob(df_glob)
            while not datafiles:
                c = util.confirm(
                    f'No data files found. Please move/link data files to {df_glob}.\n'
                    'Data files should contain both queries and documents in the '
                    'following format (one per line):\n'
                    '[query|doc] [TAB] [qid/did] [TAB] [text]')
                if not c:
                    sys.exit(1)
                datafiles = glob(df_glob)
            main_iter = itertools.chain(*(plaintext.read_tsv(df)
                                          for df in datafiles))
            main_iter = tqdm(main_iter, desc='reading datafiles')
            iters = util.blocking_tee(main_iter, len(needs_datafile))
            with contextlib.ExitStack() as stack:
                for fn, it in zip(needs_datafile, iters):
                    stack.enter_context(
                        util.CtxtThread(functools.partial(fn, it)))

        qrels_file = os.path.join(base_dir, 'qrels.txt')
        while not os.path.exists(qrels_file):
            c = util.confirm(
                f'No qrels file found. Please move/link qrels file to {qrels_file}.\n'
                'Qrels file should be in the TREC format:\n'
                '[qid] [SPACE] Q0 [SPACE] [did] [SPACE] [score]')
            if not c:
                sys.exit(1)
Example #5
 def _load_queries_base(self, subset):
     querysource = self.querysource
     query_path = os.path.join(util.path_dataset(self), f'{subset}.topics')
     return {
         qid: text
         for t, qid, text in plaintext.read_tsv(query_path)
         if t == querysource
     }
Example #6
 def _init_doctttttquery_iter(self):
     with util.download_tmp(_SOURCES['doctttttquery-predictions'], expected_md5=_HASHES['doctttttquery-predictions']) as f1, \
          util.download_tmp(_SOURCES['collection'], expected_md5=_HASHES['collection']) as f2:
         with zipfile.ZipFile(f1) as zipf, tarfile.open(fileobj=f2) as tarf:
             collection_stream = io.TextIOWrapper(tarf.extractfile('collection.tsv'))
             d5_iter = self._init_doctttttquery_zipf_iter(zipf)
             for (did, text), d5text in self.logger.pbar(zip(plaintext.read_tsv(collection_stream), d5_iter), desc='documents'):
                 yield indices.RawDoc(did, f'{text} {d5text}')
Example #7
File: msmarco.py Project: kiminh/OpenNIR
 def _init_iter_collection(self):
     with util.download_tmp(_SOURCES['collection']) as f:
         with tarfile.open(fileobj=f) as tarf:
             collection_stream = io.TextIOWrapper(
                 tarf.extractfile('collection.tsv'))
             for did, text in self.logger.pbar(
                     plaintext.read_tsv(collection_stream),
                     desc='documents'):
                 yield indices.RawDoc(did, text)
Example #8
 def _load_queries_base(self, subset):
     rnd, fields = subset.split('-', 1)
     fields = fields.split('-')
     path = os.path.join(util.path_dataset(self), f'{rnd}.tsv')
     return {
         qid: qtext
         for qid, qtype, qtext in plaintext.read_tsv(path)
         if qtype in fields
     }
Example #9
 def __init__(self, path):
     self.path = path
     if os.path.exists(path):
         self.content = {
             int(e): float(v)
             for e, v in plaintext.read_tsv(path)
         }
     else:
         self.content = {}
Example #10
File: covid.py Project: tgeral68/OpenNIR
 def _load_queries_base(self, subset):
     fold, rnd, fields = subset.split('-', 2)
     fields = fields.split('-')
     path = os.path.join(util.path_dataset(self), f'rnd5.tsv')
     filter_queries = {
         str(qid): qtext
         for qid, qtype, qtext in plaintext.read_tsv(path)
         if qtype in fields and str(qid) in FOLDS[fold]
     }
     return filter_queries
Example #11
 def load_metrics(self, ctxt):
     missing = set()
     epoch = ctxt['epoch']
     for metric in list(ctxt['metrics']):
         path_agg = os.path.join(ctxt['base_path'], metric, 'agg.txt')
         path_epoch = os.path.join(ctxt['base_path'], metric,
                                   f'{epoch}.txt')
         if os.path.exists(path_agg) and os.path.exists(path_epoch):
             ctxt['metrics'][metric] = [
                 float(v) for k, v in plaintext.read_tsv(path_agg)
                 if int(k) == epoch
             ][0]
             ctxt['metrics_by_query'][metric] = {
                 k: float(v)
                 for k, v in plaintext.read_tsv(path_epoch)
             }
         else:
             missing.add(metric)
     return missing
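
The per-metric files read here are not shown on this page; from the parsing above, agg.txt evidently holds one (epoch, aggregate value) pair per line and {epoch}.txt one (query id, value) pair per line. A hypothetical layout, using map as the metric name and made-up numbers:

    # <base_path>/map/agg.txt  (epoch TAB value)
    11    0.2031
    12    0.2104

    # <base_path>/map/12.txt   (qid TAB value)
    301    0.1842
    302    0.2267
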
Example #12
 def _load_topics(self, subset):
     result = {}
     for qid, text in plaintext.read_tsv(
             os.path.join(util.path_dataset(self), 'topics.txt')):
         #nqid=int(qid.replace('MB','').strip())
         if subset == 'valid' and (int(qid) in VALIDATION_QIDS):
             result[qid] = text
         elif subset == 'test' and (int(qid) in TEST_QIDS):
             result[qid] = text
         elif subset == 'train' and (int(qid) not in VALIDATION_QIDS) and (
                 int(qid) not in TEST_QIDS):
             result[qid] = text
     return result
Example #13
 def wrapped(logger, get_kernels=False):
     with tempfile.TemporaryDirectory() as p:
         if not get_kernels:
             vocab_path = os.path.join(p, 'vocab')
             with logger.duration(f'downloading {base_url}vocab'):
                 util.download(base_url + 'vocab', vocab_path)
             with logger.duration(f'reading vocab'):
                 v = {}
                 for term, idx in plaintext.read_tsv(vocab_path):
                     v[int(idx)] = term
                 terms = [None] * (max(v.keys()) + 1)
                 for idx, term in v.items():
                     terms[idx] = term
             embedding_path = os.path.join(p, 'embedding')
             with logger.duration(f'downloading {base_url}embedding'):
                 util.download(base_url + 'embedding', embedding_path)
             with logger.duration(f'reading embedding'):
                 weights = None
                 for values in plaintext.read_sv(embedding_path, sep=' '):
                     if len(values) == 2:
                         weights = np.ndarray(
                             (int(values[0]), int(values[1])))
                     else:
                         idx, values = values[0], values[1:]
                         weights[int(idx)] = [float(v) for v in values]
             return terms, weights
         else:  # get_kernels
             w, b = [], []
             for f in range(1, 4):
                 url = f'{base_url}filter{f}'
                 path = os.path.join(p, f'filter{f}')
                 with logger.duration(f'downloading {url}'):
                     util.download(url, path)
                 with logger.duration(f'reading filter{f}'):
                     weights, biases = None, None
                     for i, values in enumerate(
                             plaintext.read_sv(path, sep=' ')):
                         if i == 0:
                             weights = np.ndarray(
                                 (int(values[0]) * int(values[1]),
                                  int(values[2])))
                         elif i == 1:
                             biases = np.array([float(v) for v in values])
                         else:
                             weights[:, i -
                                     2] = [float(v) for v in values if v]
                 weights = weights.reshape(f, -1, weights.shape[1])
                 weights = np.transpose(weights, (2, 1, 0))
                 w.append(weights)
                 b.append(biases)
             return w, b
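
The embedding file parsed above apparently starts with a header line giving the matrix shape (as in the word2vec text format), followed by one integer row index and its space-separated vector components per line, which is why it is read with read_sv(..., sep=' '). A hypothetical two-row file, with made-up dimensions and values (the trailing dots stand for the remaining components):

    400000 300
    0 0.013 -0.241 0.119 ...
    1 -0.087 0.305 -0.002 ...
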
Example #14
    def init(self, force=False):
        idxs = [self.index, self.index_stem, self.doc_store]
        self._init_indices_parallel(idxs, self._init_iter_collection(), force)

        train_qrels = os.path.join(util.path_dataset(self), 'train.qrels.txt')
        valid_qrels = os.path.join(util.path_dataset(self), 'valid.qrels.txt')
        if (force or not os.path.exists(train_qrels)
                or not os.path.exists(valid_qrels)) and self._confirm_dua():
            source_stream = util.download_stream(
                'https://ciir.cs.umass.edu/downloads/Antique/antique-train.qrel',
                encoding='utf8')
            with util.finialized_file(train_qrels, 'wt') as tf, \
                 util.finialized_file(valid_qrels, 'wt') as vf:
                for line in source_stream:
                    cols = line.strip().split()
                    if cols[0] in VALIDATION_QIDS:
                        vf.write(' '.join(cols) + '\n')
                    else:
                        tf.write(' '.join(cols) + '\n')

        train_queries = os.path.join(util.path_dataset(self),
                                     'train.queries.txt')
        valid_queries = os.path.join(util.path_dataset(self),
                                     'valid.queries.txt')
        if (force or not os.path.exists(train_queries)
                or not os.path.exists(valid_queries)) and self._confirm_dua():
            source_stream = util.download_stream(
                'https://ciir.cs.umass.edu/downloads/Antique/antique-train-queries.txt',
                encoding='utf8')
            train, valid = [], []
            for cols in plaintext.read_tsv(source_stream):
                if cols[0] in VALIDATION_QIDS:
                    valid.append(cols)
                else:
                    train.append(cols)
            plaintext.write_tsv(train_queries, train)
            plaintext.write_tsv(valid_queries, valid)

        test_qrels = os.path.join(util.path_dataset(self), 'test.qrels.txt')
        if (force or not os.path.exists(test_qrels)) and self._confirm_dua():
            util.download(
                'https://ciir.cs.umass.edu/downloads/Antique/antique-test.qrel',
                test_qrels)

        test_queries = os.path.join(util.path_dataset(self),
                                    'test.queries.txt')
        if (force or not os.path.exists(test_queries)) and self._confirm_dua():
            util.download(
                'https://ciir.cs.umass.edu/downloads/Antique/antique-test-queries.txt',
                test_queries)
Example #15
 def record_iter(self, fields, source, minrel=None, shuf=True, random=None, inf=False, run_threshold=None):
     special = self.config['special']
     if special == '':
         raise NotImplementedError
     assert minrel is None or minrel < 1
     if source != 'run':
         self.logger.warn(f'Using special={special}; ignoring record_iter arguments source={source}')
     if run_threshold is not None:
         self.logger.warn(f'Using special={special}; ignoring record_iter arguments run_threshold={run_threshold}')
     first = True
     while first or inf:
         first = False
         if special == 'mspairs':
             f = gzip.open(os.path.join(util.path_dataset(self), '{subset}.mspairs.gz'.format(**self.config)), 'rt')
             it = plaintext.read_tsv(f)
             fields = fields - {'relscore'}  # don't request relscore through the usual channels (i.e., qrels), since we already know it here and this is faster.
         elif special == 'msrun':
             f = os.path.join(util.path_dataset(self), '{subset}.msrun'.format(**self.config))
             it = ((qid, did) for qid, did, rank, score in trec.read_run(f))
         elif special == 'validrun':
             f = os.path.join(util.path_dataset(self), '{subset}.validrun'.format(**self.config))
             it = plaintext.read_sv(f, ' ')
         else:
             raise ValueError(f'unsupported special={special}')
         if shuf:
             if special in ('msrun', 'mspairs'):
                 self.logger.warn(f'ignoring shuf=True with special={special}')
             else:
                 it = list(it)
                 random.shuffle(it)
         for cols in it:
             if len(cols) == 3:
                 qid, pos_did, neg_did = cols
                 dids = [pos_did, neg_did] if (minrel is None or minrel <= 0) else [pos_did]
                 if qid in MINI_DEV:
                     continue
             elif len(cols) == 2:
                 qid, did = cols
                 dids = [did]
             for did in dids:
                 record = self.build_record(fields, query_id=qid, doc_id=did)
                 result = {f: record[f] for f in fields}
                 if len(cols) == 3:
                     result['relscore'] = (1 if did == pos_did else 0)
                 yield result
Example #16
 def test_build(self):
     df = plaintext.read_tsv('etc/dummy_datafile.tsv')
     docs = [indices.RawDoc(did, dtext) for t, did, dtext in df if t == 'doc']
     with tempfile.TemporaryDirectory() as tmpdir:
         idxs = [
             (indices.AnseriniIndex(os.path.join(tmpdir, 'anserini')), False),
             (indices.AnseriniIndex(os.path.join(tmpdir, 'anserini.rawdocs'), store_raw_docs=True), True),
             (indices.SqliteDocstore(os.path.join(tmpdir, 'sqlite')), True),
         ]
         for index, check_raw_docs in idxs:
             with self.subTest(index=index):
                 self.assertFalse(index.built())
                 index.build(iter(docs))
                 self.assertTrue(index.built())
                 self.assertEqual(index.num_docs(), len(docs))
                 if check_raw_docs:
                     for doc in docs:
                         self.assertEqual(index.get_raw(doc.did), doc.data['text'])
Example #17
 def test_batch_query(self):
     df = list(plaintext.read_tsv('etc/dummy_datafile.tsv'))
     docs = [indices.RawDoc(did, dtext) for t, did, dtext in df if t == 'doc']
     queries = [(qid, qtext) for t, qid, qtext in df if t == 'query']
     with tempfile.TemporaryDirectory() as tmpdir:
         idxs = [
             indices.AnseriniIndex(os.path.join(tmpdir, 'anserini')),
         ]
         models = [
             'bm25', 'bm25_k1-1.5', 'bm25_b-0.2', 'bm25_k1-1.6_b-0.8',
             'bm25_rm3', 'bm25_rm3_k1-1.5', 'bm25_rm3_b-0.2', 'bm25_rm3_k1-1.6_b-0.8',
             'bm25_rm3_rm3.fbTerms-2_rm3.fbDocs-2', 'bm25_rm3_rm3.fbTerms-2_rm3.fbDocs-2_k1-1.5',
             'bm25_rm3_rm3.fbTerms-2_rm3.fbDocs-2_b-0.2', 'bm25_rm3_rm3.fbTerms-2_rm3.fbDocs-2_k1-1.6_b-0.8',
             'ql', 'ql_mu-0.4',
             'sdm', 'sdm_uw-0.3_ow-0.2_tw-0.5',
         ]
         for index in idxs:
             index.build(docs)
             for model in models:
                 with self.subTest(index=index, model=model):
                     index.batch_query(queries, model, topk=10)
                     index.batch_query(queries, model, topk=10, quiet=True)
Example #18
    def run(self):
        if self.config['queries']:
            logger.debug(
                'loading queries from {queries}'.format(**self.config))
            query_iter = plaintext.read_tsv(self.config['queries'])
        else:
            logger.debug('loading queries test_ds')
            query_iter = self.test_ds.all_queries_raw()

        if self.config['rerank']:
            if not self.config['dvec_file']:
                raise ValueError('must provide dvec_file')
            self._load_ranker_weights(self.ranker, self.vocab, self.trainer,
                                      self.valid_pred, self.train_ds)
            self.ranker.eval()
            input_spec = self.ranker.input_spec()
            fields = {
                f
                for f in input_spec['fields'] if f.startswith('query_')
            }
            device = util.device(self.config, logger)
            vocab_size = self.vocab.lexicon_size()
            num_docs = self.test_ds.num_docs()
            dvec_cache = EpicCacheReader(self.config['dvec_file'],
                                         self.config['prune'], num_docs,
                                         vocab_size, self.config['dvec_inmem'],
                                         self.config['gpu'])
        else:
            pass  # only do initial retrieval

        self.timer = util.DurationTimer(gpu_sync=self.config['gpu'])
        with torch.no_grad():
            if self.config['mode'] == 'time':
                self.time(query_iter, dvec_cache, fields, input_spec, device)
            if self.config['mode'] == 'predict':
                self.predict(query_iter, dvec_cache, fields, input_spec,
                             device)
Example #19
    def _init_collection(self, collection, force=False):
        base_path = util.path_dataset(self)
        if collection == '1k':
            idxs = [self.index1k, self.index1k_stem, self.docstore1k]
        elif collection == '59k':
            idxs = [self.index59k, self.index59k_stem, self.docstore59k]
        else:
            raise ValueError(f'unsupported collection {collection}')

        query_files = {
            f'wikIR{collection}/training/queries.csv':
            os.path.join(base_path, f'train.{collection}.queries'),
            f'wikIR{collection}/validation/queries.csv':
            os.path.join(base_path, f'dev.{collection}.queries'),
            f'wikIR{collection}/test/queries.csv':
            os.path.join(base_path, f'test.{collection}.queries')
        }

        qrels_files = {
            f'wikIR{collection}/training/qrels':
            os.path.join(base_path, f'train.{collection}.qrels'),
            f'wikIR{collection}/validation/qrels':
            os.path.join(base_path, f'dev.{collection}.qrels'),
            f'wikIR{collection}/test/qrels':
            os.path.join(base_path, f'test.{collection}.qrels')
        }

        theirbm25_files = {
            f'wikIR{collection}/training/BM25.res':
            os.path.join(base_path, f'train.{collection}.theirbm25'),
            f'wikIR{collection}/validation/BM25.res':
            os.path.join(base_path, f'dev.{collection}.theirbm25'),
            f'wikIR{collection}/test/BM25.res':
            os.path.join(base_path, f'test.{collection}.theirbm25')
        }

        if not force and \
           all(i.built() for i in idxs) and \
           all(os.path.exists(f) for f in query_files.values()) and \
           all(os.path.exists(f) for f in qrels_files.values()) and \
           all(os.path.exists(f) for f in theirbm25_files.values()):
            return

        if not self._confirm_dua():
            return

        with util.download_tmp(_SOURCES[collection]) as f:
            with zipfile.ZipFile(f) as zipf:
                doc_iter = self._init_iter_collection(zipf, collection)
                self._init_indices_parallel(idxs, doc_iter, force)

                for zqueryf, queryf in query_files.items():
                    if force or not os.path.exists(queryf):
                        with zipf.open(zqueryf) as f, open(queryf,
                                                           'wt') as out:
                            f = io.TextIOWrapper(f)
                            f.readline()  # head
                            for qid, text in plaintext.read_sv(f, ','):
                                plaintext.write_tsv(out, [[qid, text]])

                for zqrelf, qrelf in qrels_files.items():
                    if force or not os.path.exists(qrelf):
                        with zipf.open(zqrelf) as f, open(qrelf, 'wt') as out:
                            f = io.TextIOWrapper(f)
                            plaintext.write_sv(out, plaintext.read_tsv(f), ' ')

                for zbm25, bm25 in theirbm25_files.items():
                    if force or not os.path.exists(bm25):
                        with zipf.open(zbm25) as f, open(bm25, 'wb') as out:
                            out.write(f.read())
Example #20
 def _load_queries_base(self, subset):
     result = {}
     f = os.path.join(util.path_dataset(self), f'{subset}.queries.txt')
     for qid, text in plaintext.read_tsv(f):
         result[qid] = text
     return result
Example #21
File: nyt.py Project: bpiwowar/OpenNIR-xpm
 def _load_queries_base(self, subset):
     with logger.duration('loading queries'):
         base_path = util.path_dataset(self)
         path = os.path.join(base_path, f'{subset}.queries')
         return dict(plaintext.read_tsv(path))
Example #22
File: msmarco.py Project: kiminh/OpenNIR
 def _load_queries_base(self, subset):
     base_path = util.path_dataset(self)
     path = os.path.join(base_path, f'{subset}.queries.tsv')
     return dict(
         self.logger.pbar(plaintext.read_tsv(path), desc='loading queries'))
Example #23
File: msmarco.py Project: kiminh/OpenNIR
    def init(self, force=False):
        idxs = [self.index_stem, self.doc_store]
        self._init_indices_parallel(idxs, self._init_iter_collection(), force)

        base_path = util.path_dataset(self)

        needs_queries = []
        if force or not os.path.exists(
                os.path.join(base_path, 'train.queries.tsv')):
            needs_queries.append(lambda it: plaintext.write_tsv(
                os.path.join(base_path, 'train.queries.tsv'),
                ((qid, txt) for file, qid, txt in it
                 if file == 'queries.train.tsv' and qid not in MINI_DEV)))
        if force or not os.path.exists(
                os.path.join(base_path, 'minidev.queries.tsv')):
            needs_queries.append(lambda it: plaintext.write_tsv(
                os.path.join(base_path, 'minidev.queries.tsv'),
                ((qid, txt) for file, qid, txt in it
                 if file == 'queries.train.tsv' and qid in MINI_DEV)))
        if force or not os.path.exists(
                os.path.join(base_path, 'dev.queries.tsv')):
            needs_queries.append(lambda it: plaintext.write_tsv(
                os.path.join(base_path, 'dev.queries.tsv'),
                ((qid, txt) for file, qid, txt in it
                 if file == 'queries.dev.tsv')))
        if force or not os.path.exists(
                os.path.join(base_path, 'eval.queries.tsv')):
            needs_queries.append(lambda it: plaintext.write_tsv(
                os.path.join(base_path, 'eval.queries.tsv'),
                ((qid, txt) for file, qid, txt in it
                 if file == 'queries.eval.tsv')))

        if needs_queries and self._confirm_dua():
            with util.download_tmp(_SOURCES['queries']) as f, \
                 tarfile.open(fileobj=f) as tarf, \
                 contextlib.ExitStack() as ctxt:

                def _extr_subf(subf):
                    for qid, txt in plaintext.read_tsv(
                            io.TextIOWrapper(tarf.extractfile(subf))):
                        yield subf, qid, txt

                query_iter = [
                    _extr_subf('queries.train.tsv'),
                    _extr_subf('queries.dev.tsv'),
                    _extr_subf('queries.eval.tsv')
                ]
                query_iter = tqdm(itertools.chain(*query_iter), desc='queries')
                query_iters = util.blocking_tee(query_iter, len(needs_queries))
                for fn, it in zip(needs_queries, query_iters):
                    ctxt.enter_context(
                        util.CtxtThread(functools.partial(fn, it)))

        file = os.path.join(base_path, 'train.qrels')
        if (force or not os.path.exists(file)) and self._confirm_dua():
            stream = util.download_stream(_SOURCES['train-qrels'], 'utf8')
            with util.finialized_file(file, 'wt') as out:
                for qid, _, did, score in plaintext.read_tsv(stream):
                    if qid not in MINI_DEV:
                        trec.write_qrels(out, [(qid, did, score)])

        file = os.path.join(base_path, 'minidev.qrels')
        if (force or not os.path.exists(file)) and self._confirm_dua():
            stream = util.download_stream(_SOURCES['train-qrels'], 'utf8')
            with util.finialized_file(file, 'wt') as out:
                for qid, _, did, score in plaintext.read_tsv(stream):
                    if qid in MINI_DEV:
                        trec.write_qrels(out, [(qid, did, score)])

        file = os.path.join(base_path, 'dev.qrels')
        if (force or not os.path.exists(file)) and self._confirm_dua():
            stream = util.download_stream(_SOURCES['dev-qrels'], 'utf8')
            with util.finialized_file(file, 'wt') as out:
                for qid, _, did, score in plaintext.read_tsv(stream):
                    trec.write_qrels(out, [(qid, did, score)])

        file = os.path.join(base_path, 'train.mspairs.gz')
        if not os.path.exists(file) and os.path.exists(
                os.path.join(base_path, 'qidpidtriples.train.full')):
            # legacy
            os.rename(os.path.join(base_path, 'qidpidtriples.train.full'),
                      file)
        if (force or not os.path.exists(file)) and self._confirm_dua():
            util.download(_SOURCES['qidpidtriples.train.full'], file)

        if not self.config['init_skip_msrun']:
            for file_name, subf in [('dev.msrun', 'top1000.dev'),
                                    ('eval.msrun', 'top1000.eval'),
                                    ('train.msrun', 'top1000.train.txt')]:
                file = os.path.join(base_path, file_name)
                if (force or not os.path.exists(file)) and self._confirm_dua():
                    run = {}
                    with util.download_tmp(_SOURCES[file_name]) as f, \
                         tarfile.open(fileobj=f) as tarf:
                        for qid, did, _, _ in tqdm(
                                plaintext.read_tsv(
                                    io.TextIOWrapper(tarf.extractfile(subf)))):
                            if qid not in run:
                                run[qid] = {}
                            run[qid][did] = 0.
                    if file_name == 'train.msrun':
                        minidev = {
                            qid: dids
                            for qid, dids in run.items() if qid in MINI_DEV
                        }
                        with self.logger.duration('writing minidev.msrun'):
                            trec.write_run_dict(
                                os.path.join(base_path, 'minidev.msrun'),
                                minidev)
                        run = {
                            qid: dids
                            for qid, dids in run.items() if qid not in MINI_DEV
                        }
                    with self.logger.duration(f'writing {file_name}'):
                        trec.write_run_dict(file, run)

        query_path = os.path.join(base_path, 'trec2019.queries.tsv')
        if (force or not os.path.exists(query_path)) and self._confirm_dua():
            stream = util.download_stream(_SOURCES['trec2019.queries'], 'utf8')
            plaintext.write_tsv(query_path, plaintext.read_tsv(stream))
        msrun_path = os.path.join(base_path, 'trec2019.msrun')
        if (force or not os.path.exists(msrun_path)) and self._confirm_dua():
            run = {}
            with util.download_stream(_SOURCES['trec2019.msrun'],
                                      'utf8') as stream:
                for qid, did, _, _ in plaintext.read_tsv(stream):
                    if qid not in run:
                        run[qid] = {}
                    run[qid][did] = 0.
            with util.finialized_file(msrun_path, 'wt') as f:
                trec.write_run_dict(f, run)

        qrels_path = os.path.join(base_path, 'trec2019.qrels')
        if not os.path.exists(qrels_path) and self._confirm_dua():
            util.download(_SOURCES['trec2019.qrels'], qrels_path)
        qrels_path = os.path.join(base_path, 'judgedtrec2019.qrels')
        if not os.path.exists(qrels_path):
            os.symlink('trec2019.qrels', qrels_path)
        query_path = os.path.join(base_path, 'judgedtrec2019.queries.tsv')
        judged_qids = util.Lazy(
            lambda: trec.read_qrels_dict(qrels_path).keys())
        if (force or not os.path.exists(query_path)):
            with util.finialized_file(query_path, 'wt') as f:
                for qid, qtext in plaintext.read_tsv(
                        os.path.join(base_path, 'trec2019.queries.tsv')):
                    if qid in judged_qids():
                        plaintext.write_tsv(f, [(qid, qtext)])
        msrun_path = os.path.join(base_path, 'judgedtrec2019.msrun')
        if (force or not os.path.exists(msrun_path)) and self._confirm_dua():
            with util.finialized_file(msrun_path, 'wt') as f:
                for qid, dids in trec.read_run_dict(
                        os.path.join(base_path, 'trec2019.msrun')).items():
                    if qid in judged_qids():
                        trec.write_run_dict(f, {qid: dids})

        # A subset of dev that only contains queries that have relevance judgments
        judgeddev_path = os.path.join(base_path, 'judgeddev')
        judged_qids = util.Lazy(lambda: trec.read_qrels_dict(
            os.path.join(base_path, 'dev.qrels')).keys())
        if not os.path.exists(f'{judgeddev_path}.qrels'):
            os.symlink('dev.qrels', f'{judgeddev_path}.qrels')
        if not os.path.exists(f'{judgeddev_path}.queries.tsv'):
            with util.finialized_file(f'{judgeddev_path}.queries.tsv',
                                      'wt') as f:
                for qid, qtext in plaintext.read_tsv(
                        os.path.join(base_path, 'dev.queries.tsv')):
                    if qid in judged_qids():
                        plaintext.write_tsv(f, [(qid, qtext)])
        if self.config['init_skip_msrun']:
            if not os.path.exists(f'{judgeddev_path}.msrun'):
                with util.finialized_file(f'{judgeddev_path}.msrun',
                                          'wt') as f:
                    for qid, dids in trec.read_run_dict(
                            os.path.join(base_path, 'dev.msrun')).items():
                        if qid in judged_qids():
                            trec.write_run_dict(f, {qid: dids})

        if not self.config['init_skip_train10']:
            file = os.path.join(base_path, 'train10.queries.tsv')
            if not os.path.exists(file):
                with util.finialized_file(file, 'wt') as fout:
                    for qid, qtext in self.logger.pbar(
                            plaintext.read_tsv(
                                os.path.join(base_path, 'train.queries.tsv')),
                            desc='filtering queries for train10'):
                        if int(qid) % 10 == 0:
                            plaintext.write_tsv(fout, [(qid, qtext)])

            file = os.path.join(base_path, 'train10.qrels')
            if not os.path.exists(file):
                with util.finialized_file(file, 'wt') as fout, open(
                        os.path.join(base_path, 'train.qrels'), 'rt') as fin:
                    for line in self.logger.pbar(
                            fin, desc='filtering qrels for train10'):
                        qid = line.split()[0]
                        if int(qid) % 10 == 0:
                            fout.write(line)

            if not self.config['init_skip_msrun']:
                file = os.path.join(base_path, 'train10.msrun')
                if not os.path.exists(file):
                    with util.finialized_file(file, 'wt') as fout, open(
                            os.path.join(base_path, 'train.msrun'),
                            'rt') as fin:
                        for line in self.logger.pbar(
                                fin, desc='filtering msrun for train10'):
                            qid = line.split()[0]
                            if int(qid) % 10 == 0:
                                fout.write(line)

            file = os.path.join(base_path, 'train10.mspairs.gz')
            if not os.path.exists(file):
                with gzip.open(file, 'wt') as fout, gzip.open(
                        os.path.join(base_path, 'train.mspairs.gz'),
                        'rt') as fin:
                    for qid, did1, did2 in self.logger.pbar(
                            plaintext.read_tsv(fin),
                            desc='filtering mspairs for train10'):
                        if int(qid) % 10 == 0:
                            plaintext.write_tsv(fout, [(qid, did1, did2)])

        if not self.config['init_skip_train_med']:
            med_qids = util.Lazy(
                lambda: {
                    qid.strip()
                    for qid in util.download_stream(
                        'https://raw.githubusercontent.com/Georgetown-IR-Lab/covid-neural-ir/master/med-msmarco-train.txt',
                        'utf8',
                        expected_md5="dc5199de7d4a872c361f89f08b1163ef")
                })
            file = os.path.join(base_path, 'train_med.queries.tsv')
            if not os.path.exists(file):
                with util.finialized_file(file, 'wt') as fout:
                    for qid, qtext in self.logger.pbar(
                            plaintext.read_tsv(
                                os.path.join(base_path, 'train.queries.tsv')),
                            desc='filtering queries for train_med'):
                        if qid in med_qids():
                            plaintext.write_tsv(fout, [(qid, qtext)])

            file = os.path.join(base_path, 'train_med.qrels')
            if not os.path.exists(file):
                with util.finialized_file(file, 'wt') as fout, open(
                        os.path.join(base_path, 'train.qrels'), 'rt') as fin:
                    for line in self.logger.pbar(
                            fin, desc='filtering qrels for train_med'):
                        qid = line.split()[0]
                        if qid in med_qids():
                            fout.write(line)

            if not self.config['init_skip_msrun']:
                file = os.path.join(base_path, 'train_med.msrun')
                if not os.path.exists(file):
                    with util.finialized_file(file, 'wt') as fout, open(
                            os.path.join(base_path, 'train.msrun'),
                            'rt') as fin:
                        for line in self.logger.pbar(
                                fin, desc='filtering msrun for train_med'):
                            qid = line.split()[0]
                            if qid in med_qids():
                                fout.write(line)

            file = os.path.join(base_path, 'train_med.mspairs.gz')
            if not os.path.exists(file):
                with gzip.open(file, 'wt') as fout, gzip.open(
                        os.path.join(base_path, 'train.mspairs.gz'),
                        'rt') as fin:
                    for qid, did1, did2 in self.logger.pbar(
                            plaintext.read_tsv(fin),
                            desc='filtering mspairs for train_med'):
                        if qid in med_qids():
                            plaintext.write_tsv(fout, [(qid, did1, did2)])
Example #24
File: msmarco.py Project: kiminh/OpenNIR
 def _extr_subf(subf):
     for qid, txt in plaintext.read_tsv(
             io.TextIOWrapper(tarf.extractfile(subf))):
         yield subf, qid, txt
Example #25
 def _load_queries_base(self, current, subset):
     base_path = util.path_dataset(self)
     path = os.path.join(base_path, f'{current}-{subset}.queries.tsv')
     return dict(
         self.logger.pbar(plaintext.read_tsv(path),
                          desc=f'loading queries CURRENT={current}'))
Example #26
File: car.py Project: tgeral68/OpenNIR
 def _load_queries_base(self, subset):
     path = os.path.join(util.path_dataset(self), f'{subset}.queries.tsv')
     result = {}
     for cols in plaintext.read_tsv(path):
         result[cols[0]] = ' '.join(cols[1:])
     return result
Example #27
 def _load_queries_base(self, subset):
     queries_path = os.path.join(util.path_dataset(self), subset,
                                 'queries.tsv')
     return {qid: qtext for qid, qtext in plaintext.read_tsv(queries_path)}
Example #28
 def _init_iter_collection(self):
     strm = util.download_stream(
         'https://ciir.cs.umass.edu/downloads/Antique/antique-collection.txt',
         'utf8')
     for did, text in plaintext.read_tsv(strm):
         yield indices.RawDoc(did, text)
Example #29
def _iter_collection(path):
    logger = log.easy()
    with path.open("rt") as collection_stream:
        for did, text in logger.pbar(plaintext.read_tsv(collection_stream),
                                     desc='documents'):
            yield indices.RawDoc(did, text)
Example #30
 def _load_topics(self):
     return dict(
         self.logger.pbar(plaintext.read_tsv(
             self.assessed_topics.topics.path),
                          desc='loading queries'))
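
Several examples also call plaintext.read_sv, plaintext.write_tsv and plaintext.write_sv. From their call sites (read_sv takes a separator, the writers take either a path or an already-open handle plus an iterable of rows), minimal counterparts consistent with the read_tsv sketch near the top of the page might look like this. Again, this is inferred from usage, not OpenNIR's actual code:

    def read_sv(file, sep):
        # same contract as the read_tsv sketch, but with a caller-chosen separator
        handle = open(file, 'rt') if isinstance(file, str) else file
        with handle:
            for line in handle:
                yield tuple(line.rstrip('\n').split(sep))

    def write_sv(file, rows, sep):
        # accept a path or an already-open handle, as in Examples #4 and #19 (assumption)
        needs_close = isinstance(file, str)
        out = open(file, 'wt') if needs_close else file
        try:
            for row in rows:
                out.write(sep.join(str(col) for col in row) + '\n')
        finally:
            if needs_close:
                out.close()

    def write_tsv(file, rows):
        write_sv(file, rows, '\t')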