def batch_query(self, queries, model, topk, destf=None, quiet=False):
    THREADS = onir.util.safe_thread_count()
    query_file_splits = 1000
    if hasattr(queries, '__len__'):
        if len(queries) < THREADS:
            THREADS = len(queries)
            query_file_splits = 1
        elif len(queries) < THREADS * 10:
            query_file_splits = (len(queries) + 1) // THREADS
        elif len(queries) < THREADS * 100:
            query_file_splits = (len(queries) + 1) // (THREADS * 10)
        else:
            query_file_splits = (len(queries) + 1) // (THREADS * 100)
    with tempfile.TemporaryDirectory() as topic_d, tempfile.TemporaryDirectory() as run_d:
        run_f = os.path.join(run_d, 'run')
        topic_files = []
        file_topic_counts = []
        current_file = None
        total_topics = 0
        for i, (qid, text) in enumerate(queries):
            topic_file = '{}/{}.queries'.format(topic_d, i // query_file_splits)
            if current_file is None or current_file.name != topic_file:
                if current_file is not None:
                    topic_files.append(current_file.name)
                    current_file.close()
                current_file = open(topic_file, 'wt')
                file_topic_counts.append(0)
            current_file.write(f'{qid}\t{text}\n')
            file_topic_counts[-1] += 1
            total_topics += 1
        if current_file is not None:
            topic_files.append(current_file.name)
            current_file.close()
        J.initialize()
        with ThreadPool(THREADS) as pool, \
             logger.pbar_raw(desc=f'batch_query ({model})', total=total_topics) as pbar:
            def fn(inputs):
                file, count = inputs
                args = J.A_SearchArgs()
                parser = J.M_CmdLineParser(args)
                arg_args = [
                    '-index', self._path,
                    '-topics', file,
                    '-output', file + '.run',
                    '-topicreader', 'TsvString',
                    '-hits', str(topk),
                    '-stemmer', self._settings['stemmer'],
                    '-indexfield', self._primary_field,
                ]
                arg_args += self._model2args(model)
                parser.parseArgument(*arg_args)
                searcher = J.A_SearchCollection(args)
                searcher.runTopics()
                searcher.close()
                return file + '.run', count
            if destf:
                result = open(destf + '.tmp', 'wb')
            else:
                result = {}
            for resultf, count in pool.imap_unordered(fn, zip(topic_files, file_topic_counts)):
                if destf:
                    with open(resultf, 'rb') as f:
                        for line in f:
                            result.write(line)
                else:
                    run = trec.read_run_dict(resultf)
                    result.update(run)
                pbar.update(count)
            if destf:
                result.close()
                shutil.move(destf + '.tmp', destf)
            else:
                return result
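
# A minimal usage sketch for batch_query (hypothetical names: `index` stands in for
# an instance of this Anserini index class; the qids/texts are illustrative):
#
#     queries = [('q1', 'neural ranking'), ('q2', 'bm25 parameter tuning')]
#     run = index.batch_query(queries, model='bm25', topk=100)
#     # run maps qid -> {did: score}. With destf set, nothing is returned; the
#     # merged TREC run is instead written atomically (via a .tmp file) to that path:
#     index.batch_query(queries, model='bm25', topk=100, destf='bm25.run')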
def __call__(self, ctxt):
    cached = True
    epoch = ctxt['epoch']
    base_path = os.path.join(ctxt['base_path'], self.pred.dataset.path_segment())
    if self.pred.config['source'] == 'run' and self.pred.config['run_threshold'] > 0:
        base_path = '{p}_runthreshold-{run_threshold}'.format(p=base_path, **self.pred.config)
    os.makedirs(os.path.join(base_path, 'runs'), exist_ok=True)
    with open(os.path.join(base_path, 'config.json'), 'wt') as f:
        json.dump(self.pred.dataset.config, f)
    run_path = os.path.join(base_path, 'runs', f'{epoch}.run')
    if os.path.exists(run_path):
        run = trec.read_run_dict(run_path)
    else:
        if self.pred.config['source'] == 'run' and self.pred.config['run_threshold'] > 0:
            official_run = self.pred.dataset.run('dict')
        else:
            official_run = {}
        run = {}
        ranker = ctxt['ranker']().to(self.device)
        this_qid = None
        these_docs = {}
        with util.finialized_file(run_path, 'wt') as f:
            for qid, did, score in self.pred.iter_scores(ranker, self.datasource, self.device):
                if qid != this_qid:
                    if this_qid is not None:
                        these_docs = self._apply_threshold(these_docs, official_run.get(this_qid, {}))
                        trec.write_run_dict(f, {this_qid: these_docs})
                    this_qid = qid
                    these_docs = {}
                these_docs[did] = score
            if this_qid is not None:
                these_docs = self._apply_threshold(these_docs, official_run.get(this_qid, {}))
                trec.write_run_dict(f, {this_qid: these_docs})
        cached = False

    result = {
        'epoch': epoch,
        'run': run,
        'run_path': run_path,
        'base_path': base_path,
        'cached': cached,
    }
    result['metrics'] = {m: None for m in self.pred.config['measures'].split(',') if m}
    result['metrics_by_query'] = {m: None for m in result['metrics']}

    missing_metrics = self.load_metrics(result)
    if missing_metrics:
        measures = set(missing_metrics)
        result['cached'] = False
        qrels = self.pred.dataset.qrels()
        calculated_metrics = onir.metrics.calc(qrels, run_path, measures)
        result['metrics_by_query'].update(calculated_metrics)
        result['metrics'].update(onir.metrics.mean(calculated_metrics))
        self.write_missing_metrics(result, missing_metrics)

    try:
        if ctxt['ranker']().config.get('add_runscore'):
            result['metrics']['runscore_alpha'] = torch.sigmoid(ctxt['ranker']().runscore_alpha).item()
            rs_alpha_f = os.path.join(ctxt['base_path'], 'runscore_alpha.txt')
            with open(rs_alpha_f, 'at') as f:
                # append this epoch's alpha to the open handle (not the path, which
                # would truncate the file rather than append)
                plaintext.write_tsv(f, [(str(epoch), str(result['metrics']['runscore_alpha']))])
    except FileNotFoundError:
        pass  # model may no longer exist; ignore

    return result
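
# `_apply_threshold` is defined elsewhere on this class. A minimal sketch of its
# apparent contract (an assumption, not the verbatim implementation): documents in
# the official run that the ranker did not score are appended below the ranker's
# lowest score, keeping their original relative order.
#
#     def _apply_threshold(self, these_docs, original_scores):
#         if not these_docs:
#             return these_docs
#         min_score = min(these_docs.values())
#         missing = original_scores.keys() - these_docs.keys()
#         for i, did in enumerate(sorted(missing, key=original_scores.get, reverse=True)):
#             these_docs[did] = min_score - i - 1  # rank strictly below scored docs
#         return these_docs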
def batch_query(self, queries, model, topk, destf=None, quiet=False):
    THREADS = onir.util.safe_thread_count()
    query_file_splits = 1000
    if hasattr(queries, '__len__'):
        if len(queries) < THREADS:
            THREADS = len(queries)
            query_file_splits = 1
        elif len(queries) < THREADS * 10:
            query_file_splits = (len(queries) + 1) // THREADS
        elif len(queries) < THREADS * 100:
            query_file_splits = (len(queries) + 1) // (THREADS * 10)
        else:
            query_file_splits = (len(queries) + 1) // (THREADS * 100)
    with tempfile.TemporaryDirectory() as topic_d, tempfile.TemporaryDirectory() as run_d:
        run_f = os.path.join(run_d, 'run')
        topic_files = []
        current_file = None
        total_topics = 0
        for i, (qid, text) in enumerate(queries):
            topic_file = '{}/{}.queries'.format(topic_d, i // query_file_splits)
            if current_file is None or current_file.name != topic_file:
                if current_file is not None:
                    topic_files.append(current_file.name)
                    current_file.close()
                current_file = open(topic_file, 'wt')
            current_file.write(f'{qid}\t{text}\n')
            total_topics += 1
        if current_file is not None:
            topic_files.append(current_file.name)
            current_file.close()
        args = J.A_SearchArgs()
        parser = J.M_CmdLineParser(args)
        arg_args = [
            '-index', self._path,
            '-topics', *topic_files,
            '-output', run_f,
            '-topicreader', 'TsvString',
            '-threads', str(THREADS),
            '-hits', str(topk),
            '-language', self._settings['lang'],
        ]
        if model.startswith('bm25'):
            arg_args.append('-bm25')
            model_args = [arg.split('-', 1) for arg in model.split('_')[1:]]
            for arg in model_args:
                if len(arg) == 1:
                    k, v = arg[0], None
                elif len(arg) == 2:
                    k, v = arg
                if k == 'k1':
                    arg_args.append('-bm25.k1')
                    arg_args.append(v)
                elif k == 'b':
                    arg_args.append('-bm25.b')
                    arg_args.append(v)
                elif k == 'rm3':
                    arg_args.append('-rm3')
                elif k == 'rm3.fbTerms':
                    arg_args.append('-rm3.fbTerms')
                    arg_args.append(v)
                elif k == 'rm3.fbDocs':
                    arg_args.append('-rm3.fbDocs')
                    arg_args.append(v)
                else:
                    raise ValueError(f'unknown bm25 parameter {arg}')
        elif model.startswith('ql'):
            arg_args.append('-qld')
            model_args = [arg.split('-', 1) for arg in model.split('_')[1:]]
            for arg in model_args:
                if len(arg) == 1:
                    k, v = arg[0], None
                elif len(arg) == 2:
                    k, v = arg
                if k == 'mu':
                    arg_args.append('-qld.mu')
                    arg_args.append(v)
                else:
                    raise ValueError(f'unknown ql parameter {arg}')
        elif model.startswith('sdm'):
            arg_args.append('-sdm')
            arg_args.append('-qld')
            model_args = [arg.split('-', 1) for arg in model.split('_')[1:]]
            for arg in model_args:
                if len(arg) == 1:
                    k, v = arg[0], None
                elif len(arg) == 2:
                    k, v = arg
                if k == 'mu':
                    arg_args.append('-qld.mu')
                    arg_args.append(v)
                elif k == 'tw':
                    arg_args.append('-sdm.tw')
                    arg_args.append(v)
                elif k == 'ow':
                    arg_args.append('-sdm.ow')
                    arg_args.append(v)
                elif k == 'uw':
                    arg_args.append('-sdm.uw')
                    arg_args.append(v)
                else:
                    raise ValueError(f'unknown sdm parameter {arg}')
        else:
            raise ValueError(f'unknown model {model}')
        parser.parseArgument(*arg_args)
        with contextlib.ExitStack() as stack:
            stack.enter_context(J.listen_java_log(_surpress_log('io.anserini.search.SearchCollection')))
            if not quiet:
                pbar = stack.enter_context(logger.pbar_raw(desc=f'batch_query ({model})', total=total_topics))
                stack.enter_context(J.listen_java_log(pbar_bq_listener(pbar)))
            searcher = J.A_SearchCollection(args)
            searcher.runTopics()
            searcher.close()
        if destf:
            shutil.copy(run_f, destf)
        else:
            return trec.read_run_dict(run_f)
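
# The model strings parsed above follow '<ranker>_<param>-<value>_...'. Example
# calls derived from the branches above (the `index` receiver is hypothetical):
#
#     index.batch_query(queries, 'bm25', topk=100)                        # Anserini defaults
#     index.batch_query(queries, 'bm25_k1-0.9_b-0.4', topk=100)           # tuned BM25
#     index.batch_query(queries, 'bm25_rm3_rm3.fbTerms-10', topk=100)     # BM25 + RM3 feedback
#     index.batch_query(queries, 'ql_mu-1000', topk=100)                  # query likelihood (Dirichlet)
#     index.batch_query(queries, 'sdm_tw-0.85_ow-0.1_uw-0.05', topk=100)  # sequential dependence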
def init(self, force=False):
    idxs = [self.index_stem, self.doc_store]
    self._init_indices_parallel(idxs, self._init_iter_collection(), force)

    base_path = util.path_dataset(self)

    needs_queries = []
    if force or not os.path.exists(os.path.join(base_path, 'train.queries.tsv')):
        needs_queries.append(lambda it: plaintext.write_tsv(
            os.path.join(base_path, 'train.queries.tsv'),
            ((qid, txt) for file, qid, txt in it
             if file == 'queries.train.tsv' and qid not in MINI_DEV)))
    if force or not os.path.exists(os.path.join(base_path, 'minidev.queries.tsv')):
        needs_queries.append(lambda it: plaintext.write_tsv(
            os.path.join(base_path, 'minidev.queries.tsv'),
            ((qid, txt) for file, qid, txt in it
             if file == 'queries.train.tsv' and qid in MINI_DEV)))
    if force or not os.path.exists(os.path.join(base_path, 'dev.queries.tsv')):
        needs_queries.append(lambda it: plaintext.write_tsv(
            os.path.join(base_path, 'dev.queries.tsv'),
            ((qid, txt) for file, qid, txt in it if file == 'queries.dev.tsv')))
    if force or not os.path.exists(os.path.join(base_path, 'eval.queries.tsv')):
        needs_queries.append(lambda it: plaintext.write_tsv(
            os.path.join(base_path, 'eval.queries.tsv'),
            ((qid, txt) for file, qid, txt in it if file == 'queries.eval.tsv')))

    if needs_queries and self._confirm_dua():
        with util.download_tmp(_SOURCES['queries']) as f, \
             tarfile.open(fileobj=f) as tarf, \
             contextlib.ExitStack() as ctxt:
            def _extr_subf(subf):
                for qid, txt in plaintext.read_tsv(io.TextIOWrapper(tarf.extractfile(subf))):
                    yield subf, qid, txt
            query_iter = [
                _extr_subf('queries.train.tsv'),
                _extr_subf('queries.dev.tsv'),
                _extr_subf('queries.eval.tsv'),
            ]
            query_iter = tqdm(itertools.chain(*query_iter), desc='queries')
            query_iters = util.blocking_tee(query_iter, len(needs_queries))
            for fn, it in zip(needs_queries, query_iters):
                ctxt.enter_context(util.CtxtThread(functools.partial(fn, it)))

    file = os.path.join(base_path, 'train.qrels')
    if (force or not os.path.exists(file)) and self._confirm_dua():
        stream = util.download_stream(_SOURCES['train-qrels'], 'utf8')
        with util.finialized_file(file, 'wt') as out:
            for qid, _, did, score in plaintext.read_tsv(stream):
                if qid not in MINI_DEV:
                    trec.write_qrels(out, [(qid, did, score)])

    file = os.path.join(base_path, 'minidev.qrels')
    if (force or not os.path.exists(file)) and self._confirm_dua():
        stream = util.download_stream(_SOURCES['train-qrels'], 'utf8')
        with util.finialized_file(file, 'wt') as out:
            for qid, _, did, score in plaintext.read_tsv(stream):
                if qid in MINI_DEV:
                    trec.write_qrels(out, [(qid, did, score)])

    file = os.path.join(base_path, 'dev.qrels')
    if (force or not os.path.exists(file)) and self._confirm_dua():
        stream = util.download_stream(_SOURCES['dev-qrels'], 'utf8')
        with util.finialized_file(file, 'wt') as out:
            for qid, _, did, score in plaintext.read_tsv(stream):
                trec.write_qrels(out, [(qid, did, score)])

    file = os.path.join(base_path, 'train.mspairs.gz')
    if not os.path.exists(file) and os.path.exists(os.path.join(base_path, 'qidpidtriples.train.full')):
        # legacy file name
        os.rename(os.path.join(base_path, 'qidpidtriples.train.full'), file)
    if (force or not os.path.exists(file)) and self._confirm_dua():
        util.download(_SOURCES['qidpidtriples.train.full'], file)

    if not self.config['init_skip_msrun']:
        for file_name, subf in [('dev.msrun', 'top1000.dev'),
                                ('eval.msrun', 'top1000.eval'),
                                ('train.msrun', 'top1000.train.txt')]:
            file = os.path.join(base_path, file_name)
            if (force or not os.path.exists(file)) and self._confirm_dua():
                run = {}
                with util.download_tmp(_SOURCES[file_name]) as f, \
                     tarfile.open(fileobj=f) as tarf:
                    for qid, did, _, _ in tqdm(plaintext.read_tsv(io.TextIOWrapper(tarf.extractfile(subf)))):
                        if qid not in run:
                            run[qid] = {}
                        run[qid][did] = 0.
                if file_name == 'train.msrun':
                    minidev = {qid: dids for qid, dids in run.items() if qid in MINI_DEV}
                    with self.logger.duration('writing minidev.msrun'):
                        trec.write_run_dict(os.path.join(base_path, 'minidev.msrun'), minidev)
                    run = {qid: dids for qid, dids in run.items() if qid not in MINI_DEV}
                with self.logger.duration(f'writing {file_name}'):
                    trec.write_run_dict(file, run)

    query_path = os.path.join(base_path, 'trec2019.queries.tsv')
    if (force or not os.path.exists(query_path)) and self._confirm_dua():
        stream = util.download_stream(_SOURCES['trec2019.queries'], 'utf8')
        plaintext.write_tsv(query_path, plaintext.read_tsv(stream))

    msrun_path = os.path.join(base_path, 'trec2019.msrun')
    if (force or not os.path.exists(msrun_path)) and self._confirm_dua():
        run = {}
        with util.download_stream(_SOURCES['trec2019.msrun'], 'utf8') as stream:
            for qid, did, _, _ in plaintext.read_tsv(stream):
                if qid not in run:
                    run[qid] = {}
                run[qid][did] = 0.
        with util.finialized_file(msrun_path, 'wt') as f:
            trec.write_run_dict(f, run)

    qrels_path = os.path.join(base_path, 'trec2019.qrels')
    if not os.path.exists(qrels_path) and self._confirm_dua():
        util.download(_SOURCES['trec2019.qrels'], qrels_path)

    qrels_path = os.path.join(base_path, 'judgedtrec2019.qrels')
    if not os.path.exists(qrels_path):
        os.symlink('trec2019.qrels', qrels_path)

    query_path = os.path.join(base_path, 'judgedtrec2019.queries.tsv')
    judged_qids = util.Lazy(lambda: trec.read_qrels_dict(qrels_path).keys())
    if force or not os.path.exists(query_path):
        with util.finialized_file(query_path, 'wt') as f:
            for qid, qtext in plaintext.read_tsv(os.path.join(base_path, 'trec2019.queries.tsv')):
                if qid in judged_qids():
                    plaintext.write_tsv(f, [(qid, qtext)])

    msrun_path = os.path.join(base_path, 'judgedtrec2019.msrun')
    if (force or not os.path.exists(msrun_path)) and self._confirm_dua():
        with util.finialized_file(msrun_path, 'wt') as f:
            for qid, dids in trec.read_run_dict(os.path.join(base_path, 'trec2019.msrun')).items():
                if qid in judged_qids():
                    trec.write_run_dict(f, {qid: dids})

    # A subset of dev that only contains queries that have relevance judgments
    judgeddev_path = os.path.join(base_path, 'judgeddev')
    judged_qids = util.Lazy(lambda: trec.read_qrels_dict(os.path.join(base_path, 'dev.qrels')).keys())
    if not os.path.exists(f'{judgeddev_path}.qrels'):
        os.symlink('dev.qrels', f'{judgeddev_path}.qrels')
    if not os.path.exists(f'{judgeddev_path}.queries.tsv'):
        with util.finialized_file(f'{judgeddev_path}.queries.tsv', 'wt') as f:
            for qid, qtext in plaintext.read_tsv(os.path.join(base_path, 'dev.queries.tsv')):
                if qid in judged_qids():
                    plaintext.write_tsv(f, [(qid, qtext)])
    if not self.config['init_skip_msrun']:
        # only build judgeddev.msrun when dev.msrun itself was built above
        if not os.path.exists(f'{judgeddev_path}.msrun'):
            with util.finialized_file(f'{judgeddev_path}.msrun', 'wt') as f:
                for qid, dids in trec.read_run_dict(os.path.join(base_path, 'dev.msrun')).items():
                    if qid in judged_qids():
                        trec.write_run_dict(f, {qid: dids})

    if not self.config['init_skip_train10']:
        file = os.path.join(base_path, 'train10.queries.tsv')
        if not os.path.exists(file):
            with util.finialized_file(file, 'wt') as fout:
                for qid, qtext in self.logger.pbar(
                        plaintext.read_tsv(os.path.join(base_path, 'train.queries.tsv')),
                        desc='filtering queries for train10'):
                    if int(qid) % 10 == 0:
                        plaintext.write_tsv(fout, [(qid, qtext)])

        file = os.path.join(base_path, 'train10.qrels')
        if not os.path.exists(file):
            with util.finialized_file(file, 'wt') as fout, \
                 open(os.path.join(base_path, 'train.qrels'), 'rt') as fin:
                for line in self.logger.pbar(fin, desc='filtering qrels for train10'):
                    qid = line.split()[0]
                    if int(qid) % 10 == 0:
                        fout.write(line)

        if not self.config['init_skip_msrun']:
            file = os.path.join(base_path, 'train10.msrun')
            if not os.path.exists(file):
                with util.finialized_file(file, 'wt') as fout, \
                     open(os.path.join(base_path, 'train.msrun'), 'rt') as fin:
                    for line in self.logger.pbar(fin, desc='filtering msrun for train10'):
                        qid = line.split()[0]
                        if int(qid) % 10 == 0:
                            fout.write(line)

        file = os.path.join(base_path, 'train10.mspairs.gz')
        if not os.path.exists(file):
            with gzip.open(file, 'wt') as fout, \
                 gzip.open(os.path.join(base_path, 'train.mspairs.gz'), 'rt') as fin:
                for qid, did1, did2 in self.logger.pbar(
                        plaintext.read_tsv(fin), desc='filtering mspairs for train10'):
                    if int(qid) % 10 == 0:
                        plaintext.write_tsv(fout, [(qid, did1, did2)])

    if not self.config['init_skip_train_med']:
        med_qids = util.Lazy(lambda: {
            qid.strip() for qid in util.download_stream(
                'https://raw.githubusercontent.com/Georgetown-IR-Lab/covid-neural-ir/master/med-msmarco-train.txt',
                'utf8', expected_md5="dc5199de7d4a872c361f89f08b1163ef")
        })

        file = os.path.join(base_path, 'train_med.queries.tsv')
        if not os.path.exists(file):
            with util.finialized_file(file, 'wt') as fout:
                for qid, qtext in self.logger.pbar(
                        plaintext.read_tsv(os.path.join(base_path, 'train.queries.tsv')),
                        desc='filtering queries for train_med'):
                    if qid in med_qids():
                        plaintext.write_tsv(fout, [(qid, qtext)])

        file = os.path.join(base_path, 'train_med.qrels')
        if not os.path.exists(file):
            with util.finialized_file(file, 'wt') as fout, \
                 open(os.path.join(base_path, 'train.qrels'), 'rt') as fin:
                for line in self.logger.pbar(fin, desc='filtering qrels for train_med'):
                    qid = line.split()[0]
                    if qid in med_qids():
                        fout.write(line)

        if not self.config['init_skip_msrun']:
            file = os.path.join(base_path, 'train_med.msrun')
            if not os.path.exists(file):
                with util.finialized_file(file, 'wt') as fout, \
                     open(os.path.join(base_path, 'train.msrun'), 'rt') as fin:
                    for line in self.logger.pbar(fin, desc='filtering msrun for train_med'):
                        qid = line.split()[0]
                        if qid in med_qids():
                            fout.write(line)

        file = os.path.join(base_path, 'train_med.mspairs.gz')
        if not os.path.exists(file):
            with gzip.open(file, 'wt') as fout, \
                 gzip.open(os.path.join(base_path, 'train.mspairs.gz'), 'rt') as fin:
                for qid, did1, did2 in self.logger.pbar(
                        plaintext.read_tsv(fin), desc='filtering mspairs for train_med'):
                    if qid in med_qids():
                        plaintext.write_tsv(fout, [(qid, did1, did2)])
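
# A minimal usage sketch (hypothetical instantiation; in OpenNIR, datasets are
# normally constructed through the config/registry machinery rather than directly):
#
#     ds = MsmarcoDataset(config)   # hypothetical name; config carries the init_skip_* flags
#     ds.init(force=False)          # idempotent: only missing artifacts are (re)built
#
# The subsets built above are deterministic: 'train10' keeps qids with
# int(qid) % 10 == 0 (roughly a tenth of train), and 'train_med' keeps the medical
# qids from the covid-neural-ir list.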