def fn(inputs):
    """Worker: run one topic file through Anserini's SearchCollection.

    `inputs` is a (topic file path, count) pair; returns the path of the
    produced run file alongside the unchanged count. Relies on the
    enclosing scope for `self`, `topk`, and `model`.
    """
    topic_file, count = inputs
    run_path = topic_file + '.run'
    search_args = J.A_SearchArgs()
    cli_parser = J.M_CmdLineParser(search_args)
    cli = [
        '-index', self._path,
        '-topics', topic_file,
        '-output', run_path,
        '-topicreader', 'TsvString',
        '-hits', str(topk),
        '-stemmer', self._settings['stemmer'],
        '-indexfield', self._primary_field,
    ]
    cli.extend(self._model2args(model))
    cli_parser.parseArgument(*cli)
    searcher = J.A_SearchCollection(search_args)
    searcher.runTopics()
    searcher.close()
    return run_path, count
def batch_query(self, queries, model, topk, destf=None, quiet=False):
    """Run a batch of queries against the index via Anserini SearchCollection.

    Args:
        queries: iterable of (qid, text) pairs. When it supports len(), the
            thread count and queries-per-topic-file split are tuned to its size.
        model: model spec string, e.g. 'bm25_k1-0.9_b-0.4', 'ql_mu-1000',
            'sdm_mu-1000_tw-0.8' (see _model_search_args for the grammar).
        topk: number of hits to retrieve per query.
        destf: if given, the raw TREC run file is copied there and nothing is
            returned; otherwise the run is returned as a dict
            (trec.read_run_dict).
        quiet: when True, suppress the progress bar.

    Raises:
        ValueError: if `model` names an unknown model family or parameter.
    """
    n_threads = onir.util.safe_thread_count()
    query_file_splits = 1000
    # When the query count is known, shrink the per-file split so the work
    # spreads across all threads instead of piling into one topic file.
    if hasattr(queries, '__len__'):
        if len(queries) < n_threads:
            n_threads = len(queries)
            query_file_splits = 1
        elif len(queries) < n_threads * 10:
            query_file_splits = (len(queries) + 1) // n_threads
        elif len(queries) < n_threads * 100:
            query_file_splits = (len(queries) + 1) // (n_threads * 10)
        else:
            query_file_splits = (len(queries) + 1) // (n_threads * 100)
    with tempfile.TemporaryDirectory() as topic_d, tempfile.TemporaryDirectory() as run_d:
        run_f = os.path.join(run_d, 'run')
        topic_files, total_topics = self._write_topic_files(queries, topic_d, query_file_splits)
        args = J.A_SearchArgs()
        parser = J.M_CmdLineParser(args)
        arg_args = [
            '-index', self._path,
            '-topics', *topic_files,
            '-output', run_f,
            '-topicreader', 'TsvString',
            '-threads', str(n_threads),
            '-hits', str(topk),
            '-language', self._settings['lang'],
        ]
        arg_args += self._model_search_args(model)
        parser.parseArgument(*arg_args)
        with contextlib.ExitStack() as stack:
            # Silence Anserini's own logging; progress is surfaced through
            # the pbar listener instead (unless quiet).
            stack.enter_context(J.listen_java_log(_surpress_log('io.anserini.search.SearchCollection')))
            if not quiet:
                pbar = stack.enter_context(logger.pbar_raw(desc=f'batch_query ({model})', total=total_topics))
                stack.enter_context(J.listen_java_log(pbar_bq_listener(pbar)))
            searcher = J.A_SearchCollection(args)
            searcher.runTopics()
            searcher.close()
        if destf:
            shutil.copy(run_f, destf)
        else:
            return trec.read_run_dict(run_f)

def _write_topic_files(self, queries, topic_d, query_file_splits):
    """Write (qid, text) pairs as tab-separated topic files under `topic_d`,
    `query_file_splits` queries per file.

    Returns:
        (list of topic file paths, total number of queries written).
    """
    topic_files = []
    current_file = None
    total_topics = 0
    for i, (qid, text) in enumerate(queries):
        topic_file = '{}/{}.queries'.format(topic_d, i // query_file_splits)
        # Queries arrive in order, so the target file only changes every
        # `query_file_splits` items; roll over to a new file when it does.
        if current_file is None or current_file.name != topic_file:
            if current_file is not None:
                topic_files.append(current_file.name)
                current_file.close()
            current_file = open(topic_file, 'wt')
        current_file.write(f'{qid}\t{text}\n')
        total_topics += 1
    if current_file is not None:
        topic_files.append(current_file.name)
        current_file.close()
    return topic_files, total_topics

def _model_search_args(self, model):
    """Translate a model spec string into Anserini SearchCollection CLI args.

    Grammar: '<family>[_<key>[-<value>]...]' where family is one of
    bm25 / ql / sdm; e.g. 'bm25_k1-0.9_b-0.4_rm3'. Matching is by prefix
    (startswith), checked in the order below.

    Raises:
        ValueError: on an unknown family or an unknown parameter key.
    """
    # family -> (base flags, params taking a value, bare-flag params)
    families = [
        ('bm25', ['-bm25'],
         {'k1': '-bm25.k1', 'b': '-bm25.b',
          'rm3.fbTerms': '-rm3.fbTerms', 'rm3.fbDocs': '-rm3.fbDocs'},
         {'rm3': '-rm3'}),
        ('ql', ['-qld'],
         {'mu': '-qld.mu'},
         {}),
        ('sdm', ['-sdm', '-qld'],
         {'mu': '-qld.mu', 'tw': '-sdm.tw', 'ow': '-sdm.ow', 'uw': '-sdm.uw'},
         {}),
    ]
    for family, base, valued, flags in families:
        if not model.startswith(family):
            continue
        result = list(base)
        for arg in [a.split('-', 1) for a in model.split('_')[1:]]:
            k, v = (arg[0], None) if len(arg) == 1 else arg
            if k in flags:
                result.append(flags[k])
            elif k in valued:
                result.append(valued[k])
                result.append(v)
            else:
                raise ValueError(f'unknown {family} parameter {arg}')
        return result
    raise ValueError(f'unknown model {model}')