def main(args):
    parser = argparse.ArgumentParser(
        prog='ir_datasets export',
        description='Exports documents, queries, qrels, and scoreddocs in various formats.')
    parser.add_argument('dataset')
    parser.set_defaults(out=sys.stdout)
    subparsers = parser.add_subparsers(dest='data')
    subparsers.required = True

    subparser = subparsers.add_parser('docs')
    subparser.add_argument('--format', choices=DEFAULT_EXPORTERS.keys(), default='tsv')
    subparser.add_argument('--fields', nargs='+')
    subparser.set_defaults(fn=main_docs)

    subparser = subparsers.add_parser('queries')
    subparser.add_argument('--format', choices=DEFAULT_EXPORTERS.keys(), default='tsv')
    subparser.add_argument('--fields', nargs='+')
    subparser.set_defaults(fn=main_queries)

    subparser = subparsers.add_parser('qrels')
    subparser.add_argument('--format', choices=['trec'], default='trec')
    subparser.set_defaults(fn=main_qrels)

    subparser = subparsers.add_parser('scoreddocs')
    subparser.add_argument('--format', choices=['trec'], default='trec')
    subparser.add_argument('--runtag', default='run')
    subparser.set_defaults(fn=main_scoreddocs)

    args = parser.parse_args(args)
    try:
        dataset = ir_datasets.load(args.dataset)
    except KeyError:
        sys.stderr.write(f"Dataset {args.dataset} not found.\n")
        sys.exit(1)
    try:
        args.fn(dataset, args)
    except BrokenPipeError:
        sys.stderr.close()
    except KeyboardInterrupt:
        sys.stderr.close()
    except AssertionError as e:
        if str(e):
            sys.stderr.write(str(e) + '\n')
        else:
            raise
def dataset_to_collection(name):
    # adapted from https://github.com/Georgetown-IR-Lab/OpenNIR/blob/master/onir/datasets/irds.py#L47
    # HACK: find "parent" dataset that contains same docs handler so we don't re-build the index for the same collection
    ds = ir_datasets.load(name)
    segments = name.split("/")
    docs_handler = ds.docs_handler()
    parent_docs_ds = name
    while len(segments) > 1:
        segments = segments[:-1]
        parent_ds = ir_datasets.load("/".join(segments))
        if parent_ds.has_docs() and parent_ds.docs_handler() == docs_handler:
            parent_docs_ds = "/".join(segments)
    return parent_docs_ds
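
# Hedged usage sketch for dataset_to_collection. The dataset IDs below are illustrative assumptions
# (any subset that shares its docs handler with a corpus-level parent behaves the same way):
assert dataset_to_collection('msmarco-passage/dev') == 'msmarco-passage'  # subset resolves to parent corpus
assert dataset_to_collection('msmarco-passage') == 'msmarco-passage'      # corpus ID resolves to itself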
def test_gov2_docstore(self):
    docstore = ir_datasets.load('gov2').docs_store()
    docstore.clear_cache()
    with _logger.duration('cold fetch'):
        docstore.get_many(['GX269-06-1933735', 'GX269-06-16539507', 'GX002-04-0481202'])
    with _logger.duration('warm fetch'):
        docstore.get_many(['GX269-06-1933735', 'GX269-06-16539507', 'GX002-04-0481202'])
    docstore = ir_datasets.load('gov2').docs_store()
    with _logger.duration('warm fetch (new docstore)'):
        docstore.get_many(['GX269-06-1933735', 'GX269-06-16539507', 'GX002-04-0481202'])
    with _logger.duration('cold fetch (nearby)'):
        docstore.get_many(['GX269-06-16476479', 'GX269-06-1939325', 'GX002-04-0587205'])
    with _logger.duration('cold fetch (earlier)'):
        docstore.get_many(['GX269-06-0125294', 'GX002-04-0050816'])
def _build_test_docs(self, dataset_name, include_count=True, include_idxs=(0, 9)):
    items = {}
    count = 0
    if isinstance(dataset_name, str):
        dataset = ir_datasets.load(dataset_name)
    else:
        dataset = dataset_name
    for i, doc in enumerate(_logger.pbar(dataset.docs_iter(), f'{dataset_name} docs', unit='doc')):
        count += 1
        if i in include_idxs:
            items[i] = doc
        if not include_count and ((include_idxs[-1] < 1000 and i == 1000) or (include_idxs[-1] >= 1000 and i == include_idxs[-1])):
            break
    items[count - 1] = doc
    items = {k: self._replace_regex_namedtuple(v) for k, v in items.items()}
    count = f', count={count}' if include_count else ''
    _logger.info(f'''
self._test_docs({repr(dataset_name)}{count},
    items={self._repr_namedtuples(items)})
''')
def main(args):
    parser = argparse.ArgumentParser(
        prog='ir_datasets lookup',
        description='Provides fast lookups of documents and queries '
                    'using docs_store. Unlike using the exporter and grep (or similar), this tool builds '
                    'an index for O(log(n)) lookups.')
    parser.add_argument('dataset')
    parser.set_defaults(out=sys.stdout)
    parser.add_argument('--format', choices=DEFAULT_EXPORTERS.keys(), default='tsv')
    parser.add_argument('--fields', nargs='+')
    parser.add_argument('--qid', '--query_id', '-q', action='store_true')
    parser.add_argument('ids', nargs='+')
    args = parser.parse_args(args)
    try:
        dataset = ir_datasets.load(args.dataset)
    except KeyError:
        sys.stderr.write(f"Dataset {args.dataset} not found.\n")
        sys.exit(1)
    if args.qid:
        qid_lookup(dataset, args)
    else:
        did_lookup(dataset, args)
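
# Hedged command-line sketch derived from the argparse definition above; the dataset ID and the
# looked-up IDs are illustrative, not guaranteed values:
#   ir_datasets lookup msmarco-passage 7501563 2351412
#   ir_datasets lookup msmarco-passage --qid 1102400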
def dataset(self):
    if not self.ird_dataset_name:
        raise ValueError("ird_dataset_name not set")
    if not self._dataset:
        self._dataset = ir_datasets.load(self.ird_dataset_name)
    return self._dataset
def do():
    dataset = ir_datasets.load('msmarco-document/' + args.query_split)
    lookup = ir_datasets.wrappers.DocstoreWrapper(dataset).queries_store()
    with open(args.run_file, "w") as fp:
        for id, query in dataset.queries_iter():
            docs = get_result(query)
            for i, d in enumerate(docs):
                fp.write("{}\t{}\t{}\n".format(id, d, i + 1))
def main(args):
    parser = argparse.ArgumentParser(
        prog='ir_datasets generate_metadata',
        description='Generates metadata for the specified datasets')
    parser.add_argument('--file', help='output file', type=Path, default=Path('ir_datasets/etc/metadata.json'))
    parser.add_argument(
        '--datasets', nargs='+',
        help='dataset IDs for which to compute metadata. If omitted, generates for all datasets present in the registry (skipping patterns)')
    args = parser.parse_args(args)
    if args.file.is_file():
        with args.file.open('rb') as f:
            data = json.load(f)
    else:
        data = {}
    if args.datasets:
        def _ds_iter():
            for dsid in args.datasets:
                yield dsid, data.get(dsid, {})
        import multiprocessing
        with multiprocessing.Pool(10) as pool:
            for dsid, dataset_metadata in _logger.pbar(
                    pool.imap_unordered(dataset2metadata, _ds_iter()),
                    desc='datasets', total=len(args.datasets)):
                if dataset_metadata is not None:
                    data[dsid] = dataset_metadata
                write_metadata_file(data, args.file)
    else:
        for dsid in ir_datasets.registry._registered:
            dataset = ir_datasets.load(dsid)
            brk = False
            try:
                _, dataset_metadata = dataset2metadata((dsid, data.get(dsid, {})))
                if dataset_metadata is not None:
                    data[dsid] = dataset_metadata
            except KeyboardInterrupt:
                _logger.info(f'KeyboardInterrupt; skipping. ctrl+c within 0.5sec to stop compute_metadata.')
                try:
                    time.sleep(0.5)
                except KeyboardInterrupt:
                    brk = True
                    break
            write_metadata_file(data, args.file)
            if brk:
                break
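
# Hedged command-line sketch derived from the argparse definition above (dataset IDs are illustrative):
#   ir_datasets generate_metadata --datasets msmarco-passage antique
#   ir_datasets generate_metadata --file ir_datasets/etc/metadata.json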
def ird_load_qrels(self):
    qrels = {}
    for name in self.ird_dataset_names:
        dataset = ir_datasets.load(name)
        for qrel in dataset.qrels_iter():
            qrels.setdefault(qrel.query_id, {})
            qrels[qrel.query_id][qrel.doc_id] = max(qrel.relevance, qrels[qrel.query_id].get(qrel.doc_id, -1))
    return qrels
def test_clueweb09_docstore(self):
    docstore = ir_datasets.load('clueweb09').docs_store()
    docstore.clear_cache()
    with _logger.duration('cold fetch'):
        result = docstore.get_many(['clueweb09-en0000-00-00003', 'clueweb09-en0000-00-35154', 'clueweb09-ar0000-48-02342'])
    self.assertEqual(len(result), 3)
    with _logger.duration('warm fetch'):
        result = docstore.get_many(['clueweb09-en0000-00-00003', 'clueweb09-en0000-00-35154', 'clueweb09-ar0000-48-02342'])
    self.assertEqual(len(result), 3)
    docstore = ir_datasets.load('clueweb09').docs_store()
    with _logger.duration('warm fetch (new docstore)'):
        result = docstore.get_many(['clueweb09-en0000-00-00003', 'clueweb09-en0000-00-35154', 'clueweb09-ar0000-48-02342'])
    self.assertEqual(len(result), 3)
    with _logger.duration('cold fetch (nearby)'):
        result = docstore.get_many(['clueweb09-en0000-00-00023', 'clueweb09-en0000-00-35167', 'clueweb09-ar0000-48-02348'])
    self.assertEqual(len(result), 3)
    with _logger.duration('cold fetch (earlier)'):
        result = docstore.get_many(['clueweb09-en0000-00-00001', 'clueweb09-ar0000-48-00009'])
    self.assertEqual(len(result), 2)
def _test_ds(self, dsid):
    with self.subTest(dsid):
        dataset = ir_datasets.load(dsid)
        metadata = dataset.metadata()
        for etype in ir_datasets.EntityType:
            if dataset.has(etype):
                self.assertTrue(etype.value in metadata, f"{dsid} missing {etype.value} metadata")
                self.assertTrue('count' in metadata[etype.value], f"{dsid} missing {etype.value} metadata")
def ird_load_topics(self):
    topics = {}
    field = "description" if self.query_type == "desc" else self.query_type
    for name in self.ird_dataset_names:
        dataset = ir_datasets.load(name)
        for query in dataset.queries_iter():
            topics[query.query_id] = getattr(query, field).replace("\n", " ")
    return {self.query_type: topics}
def ird_load_qrels(self):
    qrels = {}
    for name in self.ird_dataset_names:
        year = name.split("-")[-1]
        assert len(year) == 4
        dataset = ir_datasets.load(name)
        for qrel in dataset.qrels_iter():
            qid = year + qrel.query_id
            qrels.setdefault(qid, {})
            qrels[qid][qrel.doc_id] = max(qrel.relevance, qrels[qid].get(qrel.doc_id, -1))
    return qrels
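
# Hedged illustration of the qid convention above, assuming a dataset name ending in '-2004'
# (the name and query_id are illustrative):
#   year = '2004'
#   qid = year + '301'   # -> '2004301'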
def eval():
    queries = []
    dataset = ir_datasets.load('msmarco-passage/' + args.query_split + '/small')
    for query_id, text in dataset.queries_iter():
        queries.append((query_id, text))
    with open(args.run_file, "w") as fp:
        for id, query in queries:
            docs = get_result(query)
            for i, d in enumerate(docs):
                fp.write("{}\t{}\t{}\n".format(id, d, i + 1))
def _build_test_docpairs(self, dataset_name):
    items = {}
    count = 0
    for i, docpair in enumerate(_logger.pbar(ir_datasets.load(dataset_name).docpairs_iter(), f'{dataset_name} docpairs')):
        count += 1
        if i in (0, 9):
            items[i] = docpair
    items[count - 1] = docpair
    _logger.info(f'''
self._test_docpairs({repr(dataset_name)}, count={count},
    items={self._repr_namedtuples(items)})
''')
def _build_test_qlogs(self, dataset_name):
    items = {}
    count = 0
    for i, qlog in enumerate(_logger.pbar(ir_datasets.load(dataset_name).qlogs_iter(), f'{dataset_name} qlogs', unit='qlogs')):
        count += 1
        if i in (0, 9):
            items[i] = qlog
    items[count - 1] = qlog
    _logger.info(f'''
self._test_qlogs({repr(dataset_name)}, count={count},
    items={self._repr_namedtuples(items)})
''')
def load_split(split):
    dataset = ir_datasets.load(split)
    docs = {}
    for doc_id, text in dataset.docs_iter():
        docs[doc_id] = text
    query = {}
    for query_id, text in dataset.queries_iter():
        query[query_id] = text
    split_data = []
    for query_id, doc_id, rel, iteration in dataset.qrels_iter():
        split_data.append([query[query_id], docs[doc_id], rel])
    return pd.DataFrame(split_data, columns=['query', 'passage', 'rel'])
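
# Hedged variant of the loader above: the tuple unpacking only works for collections whose docs and
# queries are two-field (doc_id, text) / (query_id, text) tuples. Accessing named fields is more
# robust across datasets (the 'text' field name is still an assumption here, as is load_split_by_field).
def load_split_by_field(split):
    dataset = ir_datasets.load(split)
    docs = {doc.doc_id: doc.text for doc in dataset.docs_iter()}
    queries = {query.query_id: query.text for query in dataset.queries_iter()}
    rows = [[queries[qrel.query_id], docs[qrel.doc_id], qrel.relevance]
            for qrel in dataset.qrels_iter()]
    return pd.DataFrame(rows, columns=['query', 'passage', 'rel'])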
def test_clueweb12_docstore(self):
    docstore = ir_datasets.load('clueweb12').docs_store()
    docstore.clear_cache()
    with _logger.duration('cold fetch'):
        docstore.get_many(['clueweb12-0000tw-05-00014', 'clueweb12-0000tw-05-12119', 'clueweb12-0106wb-18-19516'])
    docstore.clear_cache()
    with _logger.duration('cold fetch (cleared)'):
        docstore.get_many(['clueweb12-0000tw-05-00014', 'clueweb12-0000tw-05-12119', 'clueweb12-0106wb-18-19516'])
    with _logger.duration('warm fetch'):
        docstore.get_many(['clueweb12-0000tw-05-00014', 'clueweb12-0000tw-05-12119', 'clueweb12-0106wb-18-19516'])
    docstore = ir_datasets.load('clueweb12').docs_store()
    with _logger.duration('warm fetch (new docstore)'):
        docstore.get_many(['clueweb12-0000tw-05-00014', 'clueweb12-0000tw-05-12119', 'clueweb12-0106wb-18-19516'])
    with _logger.duration('cold fetch (nearby)'):
        docstore.get_many(['clueweb12-0000tw-05-00020', 'clueweb12-0000tw-05-12201', 'clueweb12-0106wb-18-19412'])
    with _logger.duration('cold fetch (earlier)'):
        docstore.get_many(['clueweb12-0000tw-05-00001', 'clueweb12-0106wb-18-08131'])
    docstore.clear_cache()
    with _logger.duration('cold fetch (earlier, cleared)'):
        docstore.get_many(['clueweb12-0000tw-05-00001', 'clueweb12-0106wb-18-08131'])
def ird_load_topics(self):
    topics = {}
    field = "description" if self.query_type == "desc" else self.query_type
    for name in self.ird_dataset_names:
        year = name.split("-")[-1]
        assert len(year) == 4
        dataset = ir_datasets.load(name)
        for query in dataset.queries_iter():
            qid = year + query.query_id
            topics[qid] = getattr(query, field).replace("\n", " ")
            self.query_types[qid] = query.type
    return {self.query_type: topics}
def _build_test_qrels(self, dataset_name):
    items = {}
    count = 0
    if isinstance(dataset_name, str):
        dataset = ir_datasets.load(dataset_name)
    else:
        dataset = dataset_name
    for i, qrel in enumerate(_logger.pbar(dataset.qrels_iter(), f'{dataset_name} qrels')):
        count += 1
        if i in (0, 9):
            items[i] = qrel
    items[count - 1] = qrel
    _logger.info(f'''
self._test_qrels({repr(dataset_name)}, count={count},
    items={self._repr_namedtuples(items)})
''')
def _test_docs(self, dataset_name, count=None, items=None, test_docstore=True, test_iter_split=True):
    orig_items = dict(items)
    with self.subTest('docs', dataset=dataset_name):
        if isinstance(dataset_name, str):
            dataset = ir_datasets.load(dataset_name)
        else:
            dataset = dataset_name
        expected_count = count
        items = items or {}
        count = 0
        for i, doc in enumerate(_logger.pbar(dataset.docs_iter(), f'{dataset_name} docs', unit='doc')):
            count += 1
            if i in items:
                self._assert_namedtuple(doc, items[i])
                del items[i]
                if expected_count is None and len(items) == 0:
                    break  # no point in going further
        if expected_count is not None:
            self.assertEqual(expected_count, count)
        self.assertEqual({}, items)
    if test_iter_split:
        with self.subTest('docs_iter split', dataset=dataset_name):
            it = dataset.docs_iter()
            with _logger.duration('doc lookups by index'):
                for idx, doc in orig_items.items():
                    self._assert_namedtuple(next(it[idx:idx + 1]), doc)
                    self._assert_namedtuple(it[idx], doc)
    if test_docstore:
        with self.subTest('docs_store', dataset=dataset_name):
            doc_store = dataset.docs_store()
            with _logger.duration('doc lookups by doc_id'):
                for doc in orig_items.values():
                    ret_doc = doc_store.get(doc.doc_id)
                    self._assert_namedtuple(doc, ret_doc)
def _build_test_scoreddocs(self, dataset_name):
    items = {}
    count = 0
    if isinstance(dataset_name, str):
        dataset = ir_datasets.load(dataset_name)
    else:
        dataset = dataset_name
    for i, scoreddoc in enumerate(_logger.pbar(dataset.scoreddocs_iter(), f'{dataset_name} scoreddocs', unit='scoreddoc')):
        count += 1
        if i in (0, 9):
            items[i] = scoreddoc
    items[count - 1] = scoreddoc
    _logger.info(f'''
self._test_scoreddocs({repr(dataset_name)}, count={count},
    items={self._repr_namedtuples(items)})
''')
def json(self, run_1_fn, run_2_fn=None):
    """
    Represent the data to be visualized in a JSON format. The format is specified here:
    https://github.com/capreolus-ir/diffir-private/issues/5
    :param run_1_fn: path to a TREC run file (loaded into a dict of the form {qid: {docid: score}})
    :param run_2_fn: optional path to a second TREC run file to compare against
    """
    run_1 = load_trec_run(run_1_fn)
    run_2 = load_trec_run(run_2_fn) if run_2_fn is not None else None
    dataset = ir_datasets.load(self.dataset)
    assert dataset.has_docs(), "dataset has no documents; maybe you're missing a partition like '/trec-dl-2020'"
    assert dataset.has_queries(), "dataset has no queries; maybe you're missing a partition like '/trec-dl-2020'"
    diff_queries, qid2diff, metric_name, qid2qrelscores = self.measure.query_differences(run_1, run_2, dataset=dataset)
    # _logger.info(diff_queries)
    diff_query_objects = self.create_query_objects(run_1, run_2, diff_queries, qid2diff, metric_name, dataset, qid2qrelscores=qid2qrelscores)
    doc_objects = self.create_doc_objects(diff_query_objects, dataset)
    return json.dumps({
        "meta": {
            "run1_name": run_1_fn,
            "run2_name": run_2_fn,
            "dataset": self.dataset,
            "measure": self.measure.module_name,
            # "weight": self.weight.module_name,
            "qrelDefs": dataset.qrels_defs(),
            "queryFields": dataset.queries_cls()._fields,
            "docFields": dataset.docs_cls()._fields,
            "relevanceColors": self.make_rel_colors(dataset),
        },
        "queries": diff_query_objects,
        "docs": doc_objects,
    })
def wrapped():
    BeautifulSoup = ir_datasets.lazy_libs.bs4().BeautifulSoup
    # NOTE: These rules are very specific in order to replicate the behaviour present in the official script
    # here: <https://github.com/grill-lab/trec-cast-tools/blob/8fa243a7e058ce4b1b378c99768c53546460c0fe/src/main/python/wapo_trecweb.py>
    # Specifically, things like skipping empty documents, filtering by "paragraph" subtype, and starting the
    # paragraph index at 1 are all needed to perfectly match the above script.
    # Note that the script does NOT strip HTML markup, which is meant to be removed in a later stage (e.g., indexing).
    # We do that here for user simplicity, as it allows the text to be consumed directly by various models
    # without the need for further pre-processing. (Though a bit of information is lost.)
    for wapo_doc in ir_datasets.load(dsid).docs_handler().docs_wapo_raw_iter():
        doc_id = wapo_doc['id']
        pid = itertools.count(1)  # paragraph index starts at 1
        for paragraph in wapo_doc['contents']:
            if paragraph is not None and paragraph.get('subtype') == 'paragraph' and paragraph['content'] != '':
                text = paragraph['content']
                if paragraph.get('mime') == 'text/html':
                    text = BeautifulSoup(f'<OUTER>{text}</OUTER>', 'lxml-xml').get_text()
                yield GenericDoc(f'WAPO_{doc_id}-{next(pid)}', text)
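
# Hedged sketch of the output shape, assuming a raw WAPO record with id 'abc123' that has two
# non-empty 'paragraph' entries (the id and texts are illustrative):
#   GenericDoc('WAPO_abc123-1', '...first paragraph text...')
#   GenericDoc('WAPO_abc123-2', '...second paragraph text...')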
def test_clueweb12_docs_html(self):
    self._test_docs(ir_datasets.wrappers.HtmlDocExtractor(ir_datasets.load('clueweb12')), items={
        0: WarcDoc(
            'clueweb12-0000tw-00-00000',
            'http://tsawer.net/2012/02/10/france-image-pool-2012-02-10-162252/',
            '2012-02-10T22:50:41Z',
            re.compile(b'^HTTP/1\\.1 200 OK\\\r\nDate: Fri, 10 Feb 2012 22:50:40 GMT\\\r\nServer: Apache/2\\.2\\.21 \\(Unix\\) mod_ssl/2\\.2\\.21 Op.{338}ortlink\\\r\nVary: Accept\\-Encoding,User\\-Agent\\\r\nConnection: close\\\r\nContent\\-Type: text/html; charset=UTF\\-8$', flags=16),
            re.compile('^\\\r\n\\\t\\\t\\\t France image Pool 2012\\-02\\-10 16:22:52\\\t \n \n \n \n \n rss § \n atom § \n rdf \n \n \n Photos aggregator.{736}essages\\. \n \n \n \n \n \n \n Based on Ocular Professor § Powered by WordPress \n \n \n \n \n \n \n \n \n \n \n \n \n $', flags=48),
            'text/plain'),
        9: WarcDoc(
            'clueweb12-0000tw-00-00009',
            'http://claywginn.com/2012/02/10/lessons-learned-from-a-week-on-vacation/',
            '2012-02-10T21:47:35Z',
            re.compile(b'^HTTP/1\\.1 200 OK\\\r\nDate: Fri, 10 Feb 2012 21:47:36 GMT\\\r\nServer: Apache\\\r\nX\\-Powered\\-By: PHP/5\\.2\\.17\\\r\nX\\-Pi.{45}: <http://wp\\.me/p1zQki\\-AT>; rel=shortlink\\\r\nConnection: close\\\r\nContent\\-Type: text/html; charset=UTF\\-8$', flags=16),
            re.compile('^Lessons learned from a week on vacation \\| claywginn\\.com \n \n \n \n \n Home \n About me \n Contact me \n \n.{5287} Words Posts: 21,458 Words \\(511 Avg\\.\\) \n Powered by WordPress \\| Designed by Elegant Themes \n \n $', flags=48),
            'text/plain'),
        1000: WarcDoc(
            'clueweb12-0000tw-00-01002',
            'http://beanpotscastiron.waffleironshapes.com/le-creuset-enameled-cast-iron-7-14-quart-round-french-oven-cherry-red-save-price-shopping-online/',
            '2012-02-10T21:55:43Z',
            re.compile(b'^HTTP/1\\.1 200 OK\\\r\nDate: Fri, 10 Feb 2012 21:55:42 GMT\\\r\nServer: Apache\\\r\nX\\-Pingback: http://beanpotscas.{70}waffleironshapes\\.com/\\?p=5>; rel=shortlink\\\r\nConnection: close\\\r\nContent\\-Type: text/html; charset=UTF\\-8$', flags=16),
            re.compile('^Le Creuset Enameled Cast\\-Iron 7\\-1/4\\-Quart Round French Oven, Cherry Red Save Price Shopping Online \\|.{4936}sites to earn advertising fees by advertising and linking to amazon\\.com Web Toolbar by Wibiya \n \n \n $', flags=48),
            'text/plain'),
    })
def _test_scoreddocs(self, dataset_name, count=None, items=None):
    with self.subTest('scoreddocs', dataset=dataset_name):
        if isinstance(dataset_name, str):
            dataset = ir_datasets.load(dataset_name)
        else:
            dataset = dataset_name
        expected_count = count
        items = items or {}
        count = 0
        for i, scoreddoc in enumerate(_logger.pbar(dataset.scoreddocs_iter(), f'{dataset_name} scoreddocs', unit='scoreddoc')):
            count += 1
            if i in items:
                self._assert_namedtuple(scoreddoc, items[i])
                del items[i]
                if expected_count is None and len(items) == 0:
                    break  # no point in going further
        if expected_count is not None:
            self.assertEqual(expected_count, count)
        self.assertEqual(0, len(items))
def _get_qrels(args):
    # gets the qrels, either from a file (priority) or from ir_datasets (if installed)
    if os.path.exists(args.qrels):
        return ir_measures.read_trec_qrels(args.qrels)
    irds_available = False
    try:
        import ir_datasets
        irds_available = True
    except ImportError:
        sys.stderr.write(f'Skipping ir_datasets lookup. To use this feature, install ir_datasets.\n')
    if irds_available:
        try:
            ds = ir_datasets.load(args.qrels)
            if ds.has_qrels():
                return ds.qrels_iter()
            sys.stderr.write(f'ir_datasets ID {args.qrels} found but does not provide qrels.\n')
            sys.exit(-1)
        except KeyError:
            sys.stderr.write(f'{args.qrels} not found. (checked file and ir_datasets)\n')
            sys.exit(-1)
    sys.stderr.write(f'{args.qrels} not found.\n')
    sys.exit(-1)
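
# Hedged command-line sketch of how the fallback above is typically exercised (file names, dataset
# ID, and measure are illustrative): the qrels argument may be either a TREC qrels file on disk or
# an ir_datasets ID.
#   ir_measures path/to/my.qrels path/to/my.run nDCG@10
#   ir_measures msmarco-passage/dev/judged path/to/my.run nDCG@10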
def dataset2metadata(args):
    dsid, data = args
    try:
        dataset = ir_datasets.load(dsid)
    except KeyError:
        return dsid, None
    try:
        for e in ir_datasets.EntityType:
            if dataset.has(e):
                if e.value not in data:
                    parent_id = getattr(ir_datasets, f'{e.value}_parent_id')(dsid)
                    if parent_id != dsid:
                        data[e.value] = {'_ref': parent_id}
                    else:
                        with _logger.duration(f'{dsid} {e.value}'):
                            data[e.value] = getattr(dataset, f'{e.value}_calc_metadata')()
                        _logger.info(f'{dsid} {e.value}: {data[e.value]}')
    except Exception as ex:
        _logger.info(f'{dsid} {e.value} [error]: {ex}')
        return dsid, None
    return dsid, data
DL_ANSERINI_ROBUST04 = ir_datasets.util.Download(
    [ir_datasets.util.RequestsDownload('https://git.uwaterloo.ca/jimmylin/anserini-indexes/raw/master/index-robust04-20191213.tar.gz')],
    expected_md5='15f3d001489c97849a010b0a4734d018')
DL_ANSERINI_ROBUST04 = Cache(
    TarExtract(DL_ANSERINI_ROBUST04, 'index-robust04-20191213/_h.fdt'),
    base_path / 'lucene_source.fdt')
collection = AnseriniRobustDocs(DL_ANSERINI_ROBUST04)
for ds_name in ['trec-robust04', 'trec-robust04/fold1', 'trec-robust04/fold2', 'trec-robust04/fold3', 'trec-robust04/fold4', 'trec-robust04/fold5']:
    main_ds = ir_datasets.load(ds_name)
    dataset = ir_datasets.Dataset(
        collection,
        main_ds.queries_handler(),
        main_ds.qrels_handler(),
    )
    # Register the dataset with ir_datasets
    ir_datasets.registry.register(PREFIX + ds_name, dataset)
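
# Hedged usage sketch: once registered, the wrapped datasets resolve through the normal ir_datasets
# entry point (PREFIX is whatever prefix this module registers under; the availability of docs,
# queries, and qrels is an expectation based on the handlers composed above).
robust = ir_datasets.load(PREFIX + 'trec-robust04')
assert robust.has_docs() and robust.has_queries() and robust.has_qrels()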
def __init__(self, config, logger, vocab):
    super().__init__(config, logger, vocab)
    if config['ds']:
        ds = ir_datasets.load(config['ds'])
        if not config['docs_ds']:
            # HACK: find "parent" dataset that contains same docs handler so we don't re-build the index for the same collection
            segments = config['ds'].split('/')
            docs_handler = ds.docs_handler()
            parent_docs_ds = config['ds']
            while len(segments) > 1:
                segments = segments[:-1]
                parent_ds = ir_datasets.load('/'.join(segments))
                if parent_ds.has_docs() and parent_ds.docs_handler() == docs_handler:
                    parent_docs_ds = '/'.join(segments)
            config['docs_ds'] = parent_docs_ds
        if not config['queries_ds']:
            config['queries_ds'] = config['ds']
    if config['doc_fields']:
        if not config['docs_index_fields']:
            config['docs_index_fields'] = config['doc_fields']
        if not config['docs_rerank_fields']:
            config['docs_rerank_fields'] = config['doc_fields']
    if config['query_fields']:
        if not config['queries_index_fields']:
            config['queries_index_fields'] = config['query_fields']
        if not config['queries_rerank_fields']:
            config['queries_rerank_fields'] = config['query_fields']
    self.docs_ds = ir_datasets.load(config['docs_ds'])
    self.queries_ds = ir_datasets.load(config['queries_ds'])
    assert self.docs_ds.has_docs()
    assert self.queries_ds.has_queries()
    if not config['docs_index_fields']:
        config['docs_index_fields'] = ','.join(self.docs_ds.docs_cls()._fields[1:])
        self.logger.info('auto-filled docs_index_fields as {docs_index_fields}'.format(**config))
    if not config['docs_rerank_fields']:
        config['docs_rerank_fields'] = ','.join(self.docs_ds.docs_cls()._fields[1:])
        self.logger.info('auto-filled docs_rerank_fields as {docs_rerank_fields}'.format(**config))
    if not config['queries_index_fields']:
        config['queries_index_fields'] = ','.join(self.queries_ds.queries_cls()._fields[1:])
        self.logger.info('auto-filled queries_index_fields as {queries_index_fields}'.format(**config))
    if not config['queries_rerank_fields']:
        config['queries_rerank_fields'] = ','.join(self.queries_ds.queries_cls()._fields[1:])
        self.logger.info('auto-filled queries_rerank_fields as {queries_rerank_fields}'.format(**config))
    base_path = os.path.join(util.path_dataset(self), sanitize_path(self.config['docs_ds']))
    os.makedirs(base_path, exist_ok=True)
    real_anserini_path = os.path.join(base_path, 'anserini.porter.{docs_index_fields}'.format(**self.config))
    os.makedirs(real_anserini_path, exist_ok=True)
    virtual_anserini_path = '{}.{}'.format(real_anserini_path, sanitize_path(config['queries_ds']))
    if not os.path.exists(virtual_anserini_path):
        os.symlink(real_anserini_path, virtual_anserini_path, target_is_directory=True)
    self.index = indices.AnseriniIndex(virtual_anserini_path, stemmer='porter')
    self.doc_store = indices.IrdsDocstore(self.docs_ds.docs_store(), config['docs_rerank_fields'])