def init(self, force=False):
    """Initialize the Robust04 dataset: build the document indices, download
    the qrels and topics (after DUA confirmation), and split the qrels into
    per-fold files.

    Args:
        force: when True, re-download/rebuild artifacts even if present.
    """
    base_path = util.path_dataset(self)
    idxs = [self.index, self.index_stem, self.doc_store]
    self._init_indices_parallel(idxs, self._init_iter_collection(), force)

    qrels_file = os.path.join(base_path, 'qrels.robust2004.txt')
    if (force or not os.path.exists(qrels_file)) and self._confirm_dua():
        util.download(**_FILES['qrels'], file_name=qrels_file)

    # Parse the full qrels at most once, and only if some fold file is
    # actually missing (previously re-parsed inside the loop per fold).
    all_qrels = None
    for fold in FOLDS:
        fold_qrels_file = os.path.join(base_path, f'{fold}.qrels')
        if force or not os.path.exists(fold_qrels_file):
            if all_qrels is None:
                all_qrels = trec.read_qrels_dict(qrels_file)
            fold_qrels = {
                qid: dids
                for qid, dids in all_qrels.items() if qid in FOLDS[fold]
            }
            trec.write_qrels_dict(fold_qrels_file, fold_qrels)

    query_file = os.path.join(base_path, 'topics.txt')
    if (force or not os.path.exists(query_file)) and self._confirm_dua():
        query_file_stream = util.download_stream(**_FILES['queries'],
                                                 encoding='utf8')
        with util.finialized_file(query_file, 'wt') as f:
            plaintext.write_tsv(f, trec.parse_query_format(query_file_stream))
def topics_path(self, path_topics):
    """Ensure a topics file exists at ``path_topics``, writing it from
    ``self.topics.path`` (restricted to ``self.qids``) when absent.

    Returns:
        path_topics (unchanged), for chaining.
    """
    if path_topics.is_file():
        return path_topics  # already materialized; nothing to do
    with util.finialized_file(path_topics, 'wt') as f, \
            self.topics.path.open("rt") as query_file_stream:
        # Keep only records whose query id is in this dataset's qid set.
        rows = [
            record
            for record in trec.parse_query_format(query_file_stream)
            if record[1] in self.qids
        ]
        plaintext.write_tsv(f, rows)
    return path_topics
def write_missing_metrics(self, ctxt, missing_metrics):
    """Persist newly-computed metrics to disk.

    For each metric name: append the aggregate value for this epoch to the
    running ``agg.txt`` log, and write the per-query values to a fresh
    ``{epoch}.txt`` file under the metric's directory.
    """
    epoch = ctxt['epoch']
    for name in missing_metrics:
        metric_dir = os.path.join(ctxt['base_path'], name)
        os.makedirs(metric_dir, exist_ok=True)
        # Aggregate log grows one row per epoch (append mode).
        with open(os.path.join(metric_dir, 'agg.txt'), 'at') as agg_file:
            row = (str(epoch), str(ctxt['metrics'][name]))
            plaintext.write_tsv(agg_file, [row])
        # Per-query breakdown gets its own epoch-specific file.
        plaintext.write_tsv(os.path.join(metric_dir, f'{epoch}.txt'),
                            ctxt['metrics_by_query'][name].items())
def init(self, force=False):
    """Initialize the TREC Microblog dataset: build indices, split the
    2013/2014 qrels into train/valid/test files, and write all topics.

    Bug fix: the 2014 qrels loop previously compared ``cols[0]`` (a str)
    against ``VALIDATION_QIDS`` while the 2013 loop and both TEST checks
    used ``int(cols[0])`` -- so no 2014 qrel could ever land in the
    validation split. Both streams now use the integer qid consistently.

    Args:
        force: when True, re-download/rebuild artifacts even if present.
    """
    base_path = util.path_dataset(self)
    idxs = [self.index, self.index_stem, self.doc_store]
    self._init_indices_parallel(idxs, self._init_iter_collection(), force)

    train_qrels = os.path.join(base_path, 'train.qrels.txt')
    valid_qrels = os.path.join(base_path, 'valid.qrels.txt')
    test_qrels = os.path.join(base_path, 'test.qrels.txt')
    if (force or not os.path.exists(train_qrels)
            or not os.path.exists(valid_qrels)) and self._confirm_dua():
        source_stream = util.download_stream(**_FILES['qrels_2013'],
                                             encoding='utf8')
        source_stream2 = util.download_stream(**_FILES['qrels_2014'],
                                              encoding='utf8')
        with util.finialized_file(train_qrels, 'wt') as tf, \
                util.finialized_file(valid_qrels, 'wt') as vf, \
                util.finialized_file(test_qrels, 'wt') as Tf:
            # Both years are routed through the same (now consistent) logic.
            for stream in (source_stream, source_stream2):
                for line in stream:
                    cols = line.strip().split()
                    qid = int(cols[0])
                    if qid in VALIDATION_QIDS:
                        vf.write(' '.join(cols) + '\n')
                    elif qid in TEST_QIDS:
                        Tf.write(' '.join(cols) + '\n')
                    else:
                        tf.write(' '.join(cols) + '\n')

    all_queries = os.path.join(base_path, 'topics.txt')
    if (force or not os.path.exists(all_queries)) and self._confirm_dua():
        source_stream = util.download_stream(**_FILES['queries_2013'],
                                             encoding='utf8')
        source_stream2 = util.download_stream(**_FILES['queries_2014'],
                                              encoding='utf8')
        # All topics (both years) go into one file; the qid is the numeric
        # part of the 'MB###' topic id.
        topics = []
        for stream in (source_stream, source_stream2):
            for _id, _query in trec.parse_query_mbformat(stream):
                nid = _id.replace('MB', '').strip()
                topics.append([nid, _query])
        plaintext.write_tsv(all_queries, topics)
def init(self, force=False):
    """Initialize the ANTIQUE dataset: build indices, then download the
    official train/test qrels and queries, splitting the training set into
    train/valid by VALIDATION_QIDS membership.

    Args:
        force: when True, re-download artifacts even if present.
    """
    idxs = [self.index, self.index_stem, self.doc_store]
    self._init_indices_parallel(idxs, self._init_iter_collection(), force)
    # Split the official training qrels into train/valid files.
    train_qrels = os.path.join(util.path_dataset(self), 'train.qrels.txt')
    valid_qrels = os.path.join(util.path_dataset(self), 'valid.qrels.txt')
    if (force or not os.path.exists(train_qrels)
            or not os.path.exists(valid_qrels)) and self._confirm_dua():
        source_stream = util.download_stream(
            'https://ciir.cs.umass.edu/downloads/Antique/antique-train.qrel',
            encoding='utf8')
        with util.finialized_file(train_qrels, 'wt') as tf, \
                util.finialized_file(valid_qrels, 'wt') as vf:
            for line in source_stream:
                cols = line.strip().split()
                # Route by qid: validation set gets VALIDATION_QIDS rows.
                if cols[0] in VALIDATION_QIDS:
                    vf.write(' '.join(cols) + '\n')
                else:
                    tf.write(' '.join(cols) + '\n')
    # Same train/valid split for the training queries.
    train_queries = os.path.join(util.path_dataset(self), 'train.queries.txt')
    valid_queries = os.path.join(util.path_dataset(self), 'valid.queries.txt')
    if (force or not os.path.exists(train_queries)
            or not os.path.exists(valid_queries)) and self._confirm_dua():
        source_stream = util.download_stream(
            'https://ciir.cs.umass.edu/downloads/Antique/antique-train-queries.txt',
            encoding='utf8')
        train, valid = [], []
        for cols in plaintext.read_tsv(source_stream):
            if cols[0] in VALIDATION_QIDS:
                valid.append(cols)
            else:
                train.append(cols)
        plaintext.write_tsv(train_queries, train)
        plaintext.write_tsv(valid_queries, valid)
    # Test qrels/queries are downloaded as-is (no split needed).
    test_qrels = os.path.join(util.path_dataset(self), 'test.qrels.txt')
    if (force or not os.path.exists(test_qrels)) and self._confirm_dua():
        util.download(
            'https://ciir.cs.umass.edu/downloads/Antique/antique-test.qrel',
            test_qrels)
    test_queries = os.path.join(util.path_dataset(self), 'test.queries.txt')
    if (force or not os.path.exists(test_queries)) and self._confirm_dua():
        util.download(
            'https://ciir.cs.umass.edu/downloads/Antique/antique-test-queries.txt',
            test_queries)
def init(self, force=False):
    """Initialize a user-supplied (or dummy) dataset from local datafiles.

    Builds indices, the query TSV, and verifies a qrels file exists,
    prompting the user to provide files when they are missing.

    Args:
        force: when True, rebuild indices/files even if present.
    """
    base_dir = os.path.join(util.path_dataset(self), self.subset)
    # The 'dummy' subset ships with the repo; just symlink its files in.
    if self.subset == 'dummy':
        datafile = os.path.join(base_dir, 'datafile.tsv')
        qrels = os.path.join(base_dir, 'qrels.txt')
        if not os.path.exists(datafile):
            os.symlink(os.path.abspath('etc/dummy_datafile.tsv'), datafile)
        if not os.path.exists(qrels):
            os.symlink(os.path.abspath('etc/dummy_qrels.txt'), qrels)
    # Collect one consumer callback per missing artifact; each will be fed
    # its own tee of the datafile record stream below.
    needs_datafile = []
    if force or not self.index.built():
        needs_datafile.append(lambda it: self.index.build(
            indices.RawDoc(did, txt) for t, did, txt in it if t == 'doc'))
    if force or not self.index_stem.built():
        needs_datafile.append(lambda it: self.index_stem.build(
            indices.RawDoc(did, txt) for t, did, txt in it if t == 'doc'))
    if force or not self.doc_store.built():
        needs_datafile.append(lambda it: self.doc_store.build(
            indices.RawDoc(did, txt) for t, did, txt in it if t == 'doc'))
    query_file = os.path.join(base_dir, 'queries.tsv')
    if force or not os.path.exists(query_file):
        needs_datafile.append(
            lambda it: plaintext.write_tsv(query_file, (
                (qid, txt) for t, qid, txt in it if t == 'query')))
    if needs_datafile:
        df_glob = os.path.join(base_dir, 'datafile*.tsv')
        datafiles = glob(df_glob)
        # Block until the user provides at least one datafile (or aborts).
        while not datafiles:
            c = util.confirm(
                f'No data files found. Please move/link data files to {df_glob}.\n'
                'Data files should contain both queries and documents in the '
                'following format (one per line):\n'
                '[query|doc] [TAB] [qid/did] [TAB] [text]')
            if not c:
                sys.exit(1)
            datafiles = glob(df_glob)
        main_iter = itertools.chain(*(plaintext.read_tsv(df)
                                      for df in datafiles))
        main_iter = tqdm(main_iter, desc='reading datafiles')
        # Fan the single record stream out to all consumers, each running
        # in its own thread so the tee never blocks indefinitely.
        iters = util.blocking_tee(main_iter, len(needs_datafile))
        with contextlib.ExitStack() as stack:
            for fn, it in zip(needs_datafile, iters):
                stack.enter_context(
                    util.CtxtThread(functools.partial(fn, it)))
    # qrels cannot be derived from datafiles; the user must supply them.
    qrels_file = os.path.join(base_dir, 'qrels.txt')
    while not os.path.exists(qrels_file):
        c = util.confirm(
            f'No qrels file found. Please move/link qrels file to {qrels_file}.\n'
            'Qrels file should be in the TREC format:\n'
            '[qid] [SPACE] Q0 [SPACE] [did] [SPACE] [score]')
        if not c:
            sys.exit(1)
def _init_topics(self, subset, topic_files, qid_prefix=None, encoding=None, xml_prefix=None, force=False, expected_md5=None):
    """Download one or more TREC topic files and write them as a single
    '{subset}.topics' TSV, optionally stripping a prefix from each qid.

    Skipped when the output already exists (unless force) or the user
    declines the data usage agreement.
    """
    topicf = os.path.join(util.path_dataset(self), f'{subset}.topics')
    if not force and os.path.exists(topicf):
        return
    if not self._confirm_dua():
        return
    records = []
    for url in topic_files:
        stream = util.download_stream(url, encoding,
                                      expected_md5=expected_md5)
        for rec_type, qid, text in trec.parse_query_format(stream, xml_prefix):
            if qid_prefix is not None:
                qid = qid.replace(qid_prefix, '')
            records.append((rec_type, qid, text))
    plaintext.write_tsv(topicf, records)
def init(self, force=False):
    """Initialize the TREC-COVID dataset: build indices, then fetch the
    round-1 and round-2 topic files (plus the UDel query variants) and the
    round-1 qrels.

    Refactor: the round-1 and round-2 topic/udel sections were duplicated
    verbatim except for URLs and checksums; they now share
    ``_init_round_topics``. Behavior is unchanged.

    Args:
        force: when True, rebuild indices even if present.
    """
    self._init_indices(force)
    base = util.path_dataset(self)
    self._init_round_topics(
        os.path.join(base, 'rnd1.tsv'),
        'https://ir.nist.gov/covidSubmit/data/topics-rnd1.xml',
        "cf1b605222f45f7dbc90ca8e4d9b2c31",
        'https://raw.githubusercontent.com/castorini/anserini/master/src/main/resources/topics-and-qrels/topics.covid-round1-udel.xml',
        "2915cf59ae222f0aa20b2a671f67fd7a")
    self._init_round_topics(
        os.path.join(base, 'rnd2.tsv'),
        'https://ir.nist.gov/covidSubmit/data/topics-rnd2.xml',
        "550129e71c83de3fb4d6d29a172c5842",
        'https://raw.githubusercontent.com/castorini/anserini/master/src/main/resources/topics-and-qrels/topics.covid-round2-udel.xml',
        "a8988734e6f812921d5125249c197985")
    path = os.path.join(base, 'rnd1.qrels')
    if not os.path.exists(path) and self._confirm_dua():
        util.download('https://ir.nist.gov/covidSubmit/data/qrels-rnd1.txt',
                      path,
                      expected_md5="d58586df5823e7d1d0b3619a73b31518")

def _init_indices(self, force):
    """Build any missing indices, feeding each a tee of the article stream.

    index_stem_2020 only receives articles whose date mentions 2020.
    """
    needs_docs = [index for index in
                  [self.index_stem, self.index_stem_2020, self.doc_store]
                  if force or not index.built()]
    if needs_docs and self._confirm_dua():
        with contextlib.ExitStack() as stack:
            doc_iter = self._init_iter_collection()
            doc_iter = self.logger.pbar(doc_iter, desc='articles')
            doc_iters = util.blocking_tee(doc_iter, len(needs_docs))
            for idx, it in zip(needs_docs, doc_iters):
                if idx is self.index_stem_2020:
                    it = (d for d in it if '2020' in d.data['date'])
                stack.enter_context(
                    util.CtxtThread(functools.partial(idx.build, it)))

def _init_round_topics(self, path, topics_url, topics_md5, udel_url, udel_md5):
    """Write one round's topic TSV (query/quest/narr rows per topic) and
    append the UDel query variants, marking completion with a flag file."""
    if not os.path.exists(path) and self._confirm_dua():
        with util.download_tmp(topics_url, expected_md5=topics_md5) as f, \
                util.finialized_file(path, 'wt') as fout:
            soup = BeautifulSoup(f.read(), 'lxml-xml')
            for topic in soup.find_all('topic'):
                qid = topic['number']
                plaintext.write_tsv(fout, [
                    (qid, 'query', topic.find('query').get_text()),
                    (qid, 'quest', topic.find('question').get_text()),
                    (qid, 'narr', topic.find('narrative').get_text()),
                ])
    # The flag file records that the udel variants were already appended,
    # so re-running init never duplicates them.
    udel_flag = path + '.includes_udel'
    if not os.path.exists(udel_flag):
        with open(path, 'at') as fout, \
                util.finialized_file(udel_flag, 'wt'):
            with util.download_tmp(udel_url, expected_md5=udel_md5) as f:
                soup = BeautifulSoup(f.read(), 'lxml-xml')
                for topic in soup.find_all('topic'):
                    qid = topic['number']
                    plaintext.write_tsv(fout, [
                        (qid, 'udel', topic.find('query').get_text()),
                    ])
def __setitem__(self, epoch, value):
    """Record a value for an epoch, keeping the in-memory mapping and the
    on-disk append log in sync."""
    self.content[epoch] = float(value)
    # Log the value as given (pre-float-conversion), one TSV row per call.
    row = (str(epoch), str(value))
    with open(self.path, 'at') as log:
        plaintext.write_tsv(log, [row])
def wrapped(it):
    # Write (did, headline) rows for the documents on the requested side
    # of the held-out split: held-out docs when is_heldout is True,
    # everything else otherwise.
    with util.finialized_file(file, 'wt') as out:
        for doc in it:
            doc_is_held_out = doc.did in _HELD_OUT_IDS
            if doc_is_held_out == is_heldout:
                plaintext.write_tsv(out, [(doc.did, doc.data['headline'])])
def __call__(self, ctxt):
    """Run prediction for the epoch in ``ctxt``: produce (or reuse) a run
    file, compute any missing metrics, and return an epoch/run/metrics
    result dict.

    Bug fix: the runscore_alpha log was written with
    ``plaintext.write_tsv(rs_alpha_f, ...)`` -- the *path* -- even though
    the file had just been opened for append as ``f``, so the append-mode
    handle was dead and the log was rewritten instead of appended (compare
    the agg.txt handling in write_missing_metrics). It now writes to ``f``.
    """
    cached = True
    epoch = ctxt['epoch']
    base_path = os.path.join(ctxt['base_path'],
                             self.pred.dataset.path_segment())
    if self.pred.config['source'] == 'run' and \
            self.pred.config['run_threshold'] > 0:
        base_path = '{p}_runthreshold-{run_threshold}'.format(
            p=base_path, **self.pred.config)
    os.makedirs(os.path.join(base_path, 'runs'), exist_ok=True)
    with open(os.path.join(base_path, 'config.json'), 'wt') as f:
        json.dump(self.pred.dataset.config, f)
    run_path = os.path.join(base_path, 'runs', f'{epoch}.run')
    if os.path.exists(run_path):
        run = trec.read_run_dict(run_path)
    else:
        if self.pred.config['source'] == 'run' and \
                self.pred.config['run_threshold'] > 0:
            official_run = self.pred.dataset.run('dict')
        else:
            official_run = {}
        # NOTE(review): `run` stays empty on this freshly-computed path;
        # the metrics below are computed from run_path, not from `run`.
        # Confirm callers don't rely on result['run'] when cached=False.
        run = {}
        ranker = ctxt['ranker']().to(self.device)
        this_qid = None
        these_docs = {}
        with util.finialized_file(run_path, 'wt') as f:
            # Scores arrive grouped by qid; flush each query's docs (after
            # thresholding) when the qid changes, and once more at the end.
            for qid, did, score in self.pred.iter_scores(
                    ranker, self.datasource, self.device):
                if qid != this_qid:
                    if this_qid is not None:
                        these_docs = self._apply_threshold(
                            these_docs, official_run.get(this_qid, {}))
                        trec.write_run_dict(f, {this_qid: these_docs})
                    this_qid = qid
                    these_docs = {}
                these_docs[did] = score
            if this_qid is not None:
                these_docs = self._apply_threshold(
                    these_docs, official_run.get(this_qid, {}))
                trec.write_run_dict(f, {this_qid: these_docs})
        cached = False
    result = {
        'epoch': epoch,
        'run': run,
        'run_path': run_path,
        'base_path': base_path,
        'cached': cached
    }
    result['metrics'] = {
        m: None
        for m in self.pred.config['measures'].split(',') if m
    }
    result['metrics_by_query'] = {m: None for m in result['metrics']}
    # load_metrics fills in previously-computed values and returns the
    # metrics it could not find on disk.
    missing_metrics = self.load_metrics(result)
    if missing_metrics:
        measures = set(missing_metrics)
        result['cached'] = False
        qrels = self.pred.dataset.qrels()
        calculated_metrics = onir.metrics.calc(qrels, run_path, measures)
        result['metrics_by_query'].update(calculated_metrics)
        result['metrics'].update(onir.metrics.mean(calculated_metrics))
        self.write_missing_metrics(result, missing_metrics)
    try:
        if ctxt['ranker']().config.get('add_runscore'):
            result['metrics']['runscore_alpha'] = torch.sigmoid(
                ctxt['ranker']().runscore_alpha).item()
            rs_alpha_f = os.path.join(ctxt['base_path'],
                                      'runscore_alpha.txt')
            with open(rs_alpha_f, 'at') as f:
                # FIXED: write to the append-mode handle, not the path.
                plaintext.write_tsv(f, [
                    (str(epoch), str(result['metrics']['runscore_alpha']))
                ])
    except FileNotFoundError:
        pass  # model may no longer exist, ignore
    return result
def _init_queryfile(self, in_stream, out_path, force=False):
    """Extract TREC CAR queries from ``in_stream`` into a TSV at
    ``out_path``; no-op when the file already exists unless ``force``."""
    if not force and os.path.exists(out_path):
        return
    with util.finialized_file(out_path, 'wt') as out:
        with self.logger.duration(f'extracting to {out_path}'):
            for qid, headings in car.iter_queries(in_stream):
                row = (qid, *headings)
                plaintext.write_tsv(out, [row])
def init(self, force=False):
    """Initialize the MS MARCO dataset: build indices, then download/derive
    queries, qrels, runs, and training pairs for each subset (train,
    minidev, dev, eval, TREC-DL 2019, judged variants, train10, train_med).

    Bug fix: the judgeddev msrun guard was ``if self.config['init_skip_msrun']:``
    -- inverted relative to every other msrun guard in this function. With
    skipping enabled it would try to read ``dev.msrun``, a file that is only
    created when msrun processing is NOT skipped. The guard is now
    ``if not self.config['init_skip_msrun']:``.

    Args:
        force: when True, re-download/rebuild artifacts even if present.
    """
    idxs = [self.index_stem, self.doc_store]
    self._init_indices_parallel(idxs, self._init_iter_collection(), force)
    base_path = util.path_dataset(self)

    # -- query files: queries.train.tsv is split into train/minidev by
    # MINI_DEV membership; dev/eval pass through unchanged.
    needs_queries = []
    if force or not os.path.exists(
            os.path.join(base_path, 'train.queries.tsv')):
        needs_queries.append(lambda it: plaintext.write_tsv(
            os.path.join(base_path, 'train.queries.tsv'),
            ((qid, txt) for fname, qid, txt in it
             if fname == 'queries.train.tsv' and qid not in MINI_DEV)))
    if force or not os.path.exists(
            os.path.join(base_path, 'minidev.queries.tsv')):
        needs_queries.append(lambda it: plaintext.write_tsv(
            os.path.join(base_path, 'minidev.queries.tsv'),
            ((qid, txt) for fname, qid, txt in it
             if fname == 'queries.train.tsv' and qid in MINI_DEV)))
    if force or not os.path.exists(
            os.path.join(base_path, 'dev.queries.tsv')):
        needs_queries.append(lambda it: plaintext.write_tsv(
            os.path.join(base_path, 'dev.queries.tsv'),
            ((qid, txt) for fname, qid, txt in it
             if fname == 'queries.dev.tsv')))
    if force or not os.path.exists(
            os.path.join(base_path, 'eval.queries.tsv')):
        needs_queries.append(lambda it: plaintext.write_tsv(
            os.path.join(base_path, 'eval.queries.tsv'),
            ((qid, txt) for fname, qid, txt in it
             if fname == 'queries.eval.tsv')))
    if needs_queries and self._confirm_dua():
        with util.download_tmp(_SOURCES['queries']) as f, \
                tarfile.open(fileobj=f) as tarf, \
                contextlib.ExitStack() as stack:
            def _extr_subf(subf):
                # Tag each record with the tar member it came from so the
                # writer lambdas above can filter by source file.
                for qid, txt in plaintext.read_tsv(
                        io.TextIOWrapper(tarf.extractfile(subf))):
                    yield subf, qid, txt
            query_iter = [
                _extr_subf('queries.train.tsv'),
                _extr_subf('queries.dev.tsv'),
                _extr_subf('queries.eval.tsv')
            ]
            query_iter = tqdm(itertools.chain(*query_iter), desc='queries')
            # Tee the stream to all pending writers, one thread each.
            query_iters = util.blocking_tee(query_iter, len(needs_queries))
            for fn, it in zip(needs_queries, query_iters):
                stack.enter_context(
                    util.CtxtThread(functools.partial(fn, it)))

    # -- qrels: the official train qrels are split into train/minidev.
    dest = os.path.join(base_path, 'train.qrels')
    if (force or not os.path.exists(dest)) and self._confirm_dua():
        stream = util.download_stream(_SOURCES['train-qrels'], 'utf8')
        with util.finialized_file(dest, 'wt') as out:
            for qid, _, did, score in plaintext.read_tsv(stream):
                if qid not in MINI_DEV:
                    trec.write_qrels(out, [(qid, did, score)])
    dest = os.path.join(base_path, 'minidev.qrels')
    if (force or not os.path.exists(dest)) and self._confirm_dua():
        stream = util.download_stream(_SOURCES['train-qrels'], 'utf8')
        with util.finialized_file(dest, 'wt') as out:
            for qid, _, did, score in plaintext.read_tsv(stream):
                if qid in MINI_DEV:
                    trec.write_qrels(out, [(qid, did, score)])
    dest = os.path.join(base_path, 'dev.qrels')
    if (force or not os.path.exists(dest)) and self._confirm_dua():
        stream = util.download_stream(_SOURCES['dev-qrels'], 'utf8')
        with util.finialized_file(dest, 'wt') as out:
            for qid, _, did, score in plaintext.read_tsv(stream):
                trec.write_qrels(out, [(qid, did, score)])

    # -- training pairs (qid/pos-did/neg-did triples)
    dest = os.path.join(base_path, 'train.mspairs.gz')
    if not os.path.exists(dest) and os.path.exists(
            os.path.join(base_path, 'qidpidtriples.train.full')):
        # legacy file name from an earlier version
        os.rename(os.path.join(base_path, 'qidpidtriples.train.full'), dest)
    if (force or not os.path.exists(dest)) and self._confirm_dua():
        util.download(_SOURCES['qidpidtriples.train.full'], dest)

    # -- top-1000 candidate runs; optionally skipped via config.
    if not self.config['init_skip_msrun']:
        for file_name, subf in [('dev.msrun', 'top1000.dev'),
                                ('eval.msrun', 'top1000.eval'),
                                ('train.msrun', 'top1000.train.txt')]:
            dest = os.path.join(base_path, file_name)
            if (force or not os.path.exists(dest)) and self._confirm_dua():
                run = {}
                with util.download_tmp(_SOURCES[file_name]) as f, \
                        tarfile.open(fileobj=f) as tarf:
                    for qid, did, _, _ in tqdm(
                            plaintext.read_tsv(
                                io.TextIOWrapper(tarf.extractfile(subf)))):
                        if qid not in run:
                            run[qid] = {}
                        run[qid][did] = 0.
                if file_name == 'train.msrun':
                    # Carve the MINI_DEV queries out into their own run.
                    minidev = {
                        qid: dids
                        for qid, dids in run.items() if qid in MINI_DEV
                    }
                    with self.logger.duration('writing minidev.msrun'):
                        trec.write_run_dict(
                            os.path.join(base_path, 'minidev.msrun'),
                            minidev)
                    run = {
                        qid: dids
                        for qid, dids in run.items() if qid not in MINI_DEV
                    }
                with self.logger.duration(f'writing {file_name}'):
                    trec.write_run_dict(dest, run)

    # -- TREC-DL 2019 artifacts
    query_path = os.path.join(base_path, 'trec2019.queries.tsv')
    if (force or not os.path.exists(query_path)) and self._confirm_dua():
        stream = util.download_stream(_SOURCES['trec2019.queries'], 'utf8')
        plaintext.write_tsv(query_path, plaintext.read_tsv(stream))
    msrun_path = os.path.join(base_path, 'trec2019.msrun')
    if (force or not os.path.exists(msrun_path)) and self._confirm_dua():
        run = {}
        with util.download_stream(_SOURCES['trec2019.msrun'],
                                  'utf8') as stream:
            for qid, did, _, _ in plaintext.read_tsv(stream):
                if qid not in run:
                    run[qid] = {}
                run[qid][did] = 0.
        with util.finialized_file(msrun_path, 'wt') as f:
            trec.write_run_dict(f, run)
    qrels_path = os.path.join(base_path, 'trec2019.qrels')
    if not os.path.exists(qrels_path) and self._confirm_dua():
        util.download(_SOURCES['trec2019.qrels'], qrels_path)

    # -- judged-only TREC 2019 subset (restricted to qids with judgments)
    qrels_path = os.path.join(base_path, 'judgedtrec2019.qrels')
    if not os.path.exists(qrels_path):
        os.symlink('trec2019.qrels', qrels_path)
    query_path = os.path.join(base_path, 'judgedtrec2019.queries.tsv')
    judged_qids = util.Lazy(
        lambda: trec.read_qrels_dict(qrels_path).keys())
    if force or not os.path.exists(query_path):
        with util.finialized_file(query_path, 'wt') as f:
            for qid, qtext in plaintext.read_tsv(
                    os.path.join(base_path, 'trec2019.queries.tsv')):
                if qid in judged_qids():
                    plaintext.write_tsv(f, [(qid, qtext)])
    msrun_path = os.path.join(base_path, 'judgedtrec2019.msrun')
    if (force or not os.path.exists(msrun_path)) and self._confirm_dua():
        with util.finialized_file(msrun_path, 'wt') as f:
            for qid, dids in trec.read_run_dict(
                    os.path.join(base_path, 'trec2019.msrun')).items():
                if qid in judged_qids():
                    trec.write_run_dict(f, {qid: dids})

    # A subset of dev that only contains queries that have relevance judgments
    judgeddev_path = os.path.join(base_path, 'judgeddev')
    judged_qids = util.Lazy(lambda: trec.read_qrels_dict(
        os.path.join(base_path, 'dev.qrels')).keys())
    if not os.path.exists(f'{judgeddev_path}.qrels'):
        os.symlink('dev.qrels', f'{judgeddev_path}.qrels')
    if not os.path.exists(f'{judgeddev_path}.queries.tsv'):
        with util.finialized_file(f'{judgeddev_path}.queries.tsv',
                                  'wt') as f:
            for qid, qtext in plaintext.read_tsv(
                    os.path.join(base_path, 'dev.queries.tsv')):
                if qid in judged_qids():
                    plaintext.write_tsv(f, [(qid, qtext)])
    # BUG FIX: was `if self.config['init_skip_msrun']:` (inverted) -- see
    # docstring. dev.msrun only exists when msrun processing ran above.
    if not self.config['init_skip_msrun']:
        if not os.path.exists(f'{judgeddev_path}.msrun'):
            with util.finialized_file(f'{judgeddev_path}.msrun',
                                      'wt') as f:
                for qid, dids in trec.read_run_dict(
                        os.path.join(base_path, 'dev.msrun')).items():
                    if qid in judged_qids():
                        trec.write_run_dict(f, {qid: dids})

    # -- train10: 10% sample of training data (qid % 10 == 0)
    if not self.config['init_skip_train10']:
        dest = os.path.join(base_path, 'train10.queries.tsv')
        if not os.path.exists(dest):
            with util.finialized_file(dest, 'wt') as fout:
                for qid, qtext in self.logger.pbar(
                        plaintext.read_tsv(
                            os.path.join(base_path, 'train.queries.tsv')),
                        desc='filtering queries for train10'):
                    if int(qid) % 10 == 0:
                        plaintext.write_tsv(fout, [(qid, qtext)])
        dest = os.path.join(base_path, 'train10.qrels')
        if not os.path.exists(dest):
            with util.finialized_file(dest, 'wt') as fout, open(
                    os.path.join(base_path, 'train.qrels'), 'rt') as fin:
                for line in self.logger.pbar(
                        fin, desc='filtering qrels for train10'):
                    qid = line.split()[0]
                    if int(qid) % 10 == 0:
                        fout.write(line)
        if not self.config['init_skip_msrun']:
            dest = os.path.join(base_path, 'train10.msrun')
            if not os.path.exists(dest):
                with util.finialized_file(dest, 'wt') as fout, open(
                        os.path.join(base_path, 'train.msrun'),
                        'rt') as fin:
                    for line in self.logger.pbar(
                            fin, desc='filtering msrun for train10'):
                        qid = line.split()[0]
                        if int(qid) % 10 == 0:
                            fout.write(line)
        dest = os.path.join(base_path, 'train10.mspairs.gz')
        if not os.path.exists(dest):
            with gzip.open(dest, 'wt') as fout, gzip.open(
                    os.path.join(base_path, 'train.mspairs.gz'),
                    'rt') as fin:
                for qid, did1, did2 in self.logger.pbar(
                        plaintext.read_tsv(fin),
                        desc='filtering mspairs for train10'):
                    if int(qid) % 10 == 0:
                        plaintext.write_tsv(fout, [(qid, did1, did2)])

    # -- train_med: medical-domain subset of the training queries
    if not self.config['init_skip_train_med']:
        med_qids = util.Lazy(
            lambda: {
                qid.strip()
                for qid in util.download_stream(
                    'https://raw.githubusercontent.com/Georgetown-IR-Lab/covid-neural-ir/master/med-msmarco-train.txt',
                    'utf8',
                    expected_md5="dc5199de7d4a872c361f89f08b1163ef")
            })
        dest = os.path.join(base_path, 'train_med.queries.tsv')
        if not os.path.exists(dest):
            with util.finialized_file(dest, 'wt') as fout:
                for qid, qtext in self.logger.pbar(
                        plaintext.read_tsv(
                            os.path.join(base_path, 'train.queries.tsv')),
                        desc='filtering queries for train_med'):
                    if qid in med_qids():
                        plaintext.write_tsv(fout, [(qid, qtext)])
        dest = os.path.join(base_path, 'train_med.qrels')
        if not os.path.exists(dest):
            with util.finialized_file(dest, 'wt') as fout, open(
                    os.path.join(base_path, 'train.qrels'), 'rt') as fin:
                for line in self.logger.pbar(
                        fin, desc='filtering qrels for train_med'):
                    qid = line.split()[0]
                    if qid in med_qids():
                        fout.write(line)
        if not self.config['init_skip_msrun']:
            dest = os.path.join(base_path, 'train_med.msrun')
            if not os.path.exists(dest):
                with util.finialized_file(dest, 'wt') as fout, open(
                        os.path.join(base_path, 'train.msrun'),
                        'rt') as fin:
                    for line in self.logger.pbar(
                            fin, desc='filtering msrun for train_med'):
                        qid = line.split()[0]
                        if qid in med_qids():
                            fout.write(line)
        dest = os.path.join(base_path, 'train_med.mspairs.gz')
        if not os.path.exists(dest):
            with gzip.open(dest, 'wt') as fout, gzip.open(
                    os.path.join(base_path, 'train.mspairs.gz'),
                    'rt') as fin:
                for qid, did1, did2 in self.logger.pbar(
                        plaintext.read_tsv(fin),
                        desc='filtering mspairs for train_med'):
                    if qid in med_qids():
                        plaintext.write_tsv(fout, [(qid, did1, did2)])
def _init_collection(self, collection, force=False):
    """Initialize one wikIR collection ('1k' or '59k'): build its indices
    and extract the query/qrels/BM25 files from the downloaded zip.

    Args:
        collection: which wikIR collection to initialize ('1k' or '59k').
        force: when True, rebuild/re-extract even if artifacts exist.

    Raises:
        ValueError: if ``collection`` is not '1k' or '59k'.
    """
    base_path = util.path_dataset(self)
    if collection == '1k':
        idxs = [self.index1k, self.index1k_stem, self.docstore1k]
    elif collection == '59k':
        idxs = [self.index59k, self.index59k_stem, self.docstore59k]
    else:
        raise ValueError(f'unsupported collection {collection}')
    # Map zip-internal paths to destination files for each artifact type.
    query_files = {
        f'wikIR{collection}/training/queries.csv':
        os.path.join(base_path, f'train.{collection}.queries'),
        f'wikIR{collection}/validation/queries.csv':
        os.path.join(base_path, f'dev.{collection}.queries'),
        f'wikIR{collection}/test/queries.csv':
        os.path.join(base_path, f'test.{collection}.queries')
    }
    qrels_files = {
        f'wikIR{collection}/training/qrels':
        os.path.join(base_path, f'train.{collection}.qrels'),
        f'wikIR{collection}/validation/qrels':
        os.path.join(base_path, f'dev.{collection}.qrels'),
        f'wikIR{collection}/test/qrels':
        os.path.join(base_path, f'test.{collection}.qrels')
    }
    theirbm25_files = {
        f'wikIR{collection}/training/BM25.res':
        os.path.join(base_path, f'train.{collection}.theirbm25'),
        f'wikIR{collection}/validation/BM25.res':
        os.path.join(base_path, f'dev.{collection}.theirbm25'),
        f'wikIR{collection}/test/BM25.res':
        os.path.join(base_path, f'test.{collection}.theirbm25')
    }
    # Everything already present: skip the download entirely.
    if not force and \
            all(i.built() for i in idxs) and \
            all(os.path.exists(f) for f in query_files.values()) and \
            all(os.path.exists(f) for f in qrels_files.values()) and \
            all(os.path.exists(f) for f in theirbm25_files.values()):
        return
    if not self._confirm_dua():
        return
    with util.download_tmp(_SOURCES[collection]) as f:
        with zipfile.ZipFile(f) as zipf:
            doc_iter = self._init_iter_collection(zipf, collection)
            self._init_indices_parallel(idxs, doc_iter, force)
            # Queries: CSV inside the zip -> TSV on disk.
            for zqueryf, queryf in query_files.items():
                if force or not os.path.exists(queryf):
                    with zipf.open(zqueryf) as f, open(queryf, 'wt') as out:
                        f = io.TextIOWrapper(f)
                        f.readline()  # head  (skip the CSV header row)
                        for qid, text in plaintext.read_sv(f, ','):
                            plaintext.write_tsv(out, [[qid, text]])
            # Qrels: TSV inside the zip -> space-separated on disk.
            for zqrelf, qrelf in qrels_files.items():
                if force or not os.path.exists(qrelf):
                    with zipf.open(zqrelf) as f, open(qrelf, 'wt') as out:
                        f = io.TextIOWrapper(f)
                        plaintext.write_sv(out, plaintext.read_tsv(f), ' ')
            # Provided BM25 runs: copied verbatim as bytes.
            for zbm25, bm25 in theirbm25_files.items():
                if force or not os.path.exists(bm25):
                    with zipf.open(zbm25) as f, open(bm25, 'wb') as out:
                        out.write(f.read())