Example #1
    def batch_query(self, queries, model, topk, destf=None, quiet=False):
        THREADS = onir.util.safe_thread_count()
        # query_file_splits is the number of queries per topic file; the tiers
        # below target roughly 1, 10, or 100 topic files per thread so the
        # thread pool below can balance work across threads.
        query_file_splits = 1000
        if hasattr(queries, '__len__'):
            if len(queries) < THREADS:
                THREADS = len(queries)
                query_file_splits = 1
            elif len(queries) < THREADS * 10:
                query_file_splits = (len(queries) + 1) // THREADS
            elif len(queries) < THREADS * 100:
                query_file_splits = (len(queries) + 1) // (THREADS * 10)
            else:
                query_file_splits = (len(queries) + 1) // (THREADS * 100)
        with tempfile.TemporaryDirectory() as topic_d:
            # (per-file runs are written next to each topic file, so no
            # separate run directory is needed in this variant)
            topic_files = []
            file_topic_counts = []
            current_file = None
            total_topics = 0
            for i, (qid, text) in enumerate(queries):
                topic_file = '{}/{}.queries'.format(topic_d,
                                                    i // query_file_splits)
                if current_file is None or current_file.name != topic_file:
                    if current_file is not None:
                        topic_files.append(current_file.name)
                        current_file.close()
                    current_file = open(topic_file, 'wt')
                    file_topic_counts.append(0)
                current_file.write(f'{qid}\t{text}\n')
                file_topic_counts[-1] += 1
                total_topics += 1
            if current_file is not None:
                topic_files.append(current_file.name)
                current_file.close()
            J.initialize()  # ensure the Java bridge is up before using Anserini
            with ThreadPool(THREADS) as pool, \
                 logger.pbar_raw(desc=f'batch_query ({model})', total=total_topics) as pbar:

                def fn(inputs):
                    # Worker: run Anserini's SearchCollection over one topic
                    # file, writing its results to '<file>.run'.
                    file, count = inputs
                    args = J.A_SearchArgs()
                    parser = J.M_CmdLineParser(args)
                    arg_args = [
                        '-index',
                        self._path,
                        '-topics',
                        file,
                        '-output',
                        file + '.run',
                        '-topicreader',
                        'TsvString',
                        '-hits',
                        str(topk),
                        '-stemmer',
                        self._settings['stemmer'],
                        '-indexfield',
                        self._primary_field,
                    ]
                    arg_args += self._model2args(model)
                    parser.parseArgument(*arg_args)
                    searcher = J.A_SearchCollection(args)
                    searcher.runTopics()
                    searcher.close()
                    return file + '.run', count

                if destf:
                    # Stream worker outputs into a temp file, then move it
                    # into place atomically once all workers have finished.
                    result = open(destf + '.tmp', 'wb')
                else:
                    result = {}
                for resultf, count in pool.imap_unordered(
                        fn, zip(topic_files, file_topic_counts)):
                    if destf:
                        with open(resultf, 'rb') as f:
                            for line in f:
                                result.write(line)
                    else:
                        run = trec.read_run_dict(resultf)
                        result.update(run)
                    pbar.update(count)
                if destf:
                    result.close()
                    shutil.move(destf + '.tmp', destf)
                else:
                    return result
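
A minimal usage sketch for this variant. The AnseriniIndex class name and its constructor argument are assumptions for illustration; only the batch_query signature comes from the code above.

    # Hypothetical setup: the class name and constructor are assumptions.
    index = AnseriniIndex('/path/to/anserini/index')
    queries = [('q1', 'what is information retrieval'),
               ('q2', 'neural ranking models')]

    # With destf=None, returns the merged run as {qid: {did: score}}.
    run = index.batch_query(queries, model='bm25', topk=100)

    # With destf set, writes the merged run to that path instead.
    index.batch_query(queries, model='bm25', topk=100, destf='bm25.run')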
Example #2
    def __call__(self, ctxt):
        cached = True  # flipped to False below if anything is (re)computed
        epoch = ctxt['epoch']
        base_path = os.path.join(ctxt['base_path'],
                                 self.pred.dataset.path_segment())
        if (self.pred.config['source'] == 'run'
                and self.pred.config['run_threshold'] > 0):
            base_path = '{p}_runthreshold-{run_threshold}'.format(
                p=base_path, **self.pred.config)
        os.makedirs(os.path.join(base_path, 'runs'), exist_ok=True)
        with open(os.path.join(base_path, 'config.json'), 'wt') as f:
            json.dump(self.pred.dataset.config, f)
        run_path = os.path.join(base_path, 'runs', f'{epoch}.run')
        if os.path.exists(run_path):
            run = trec.read_run_dict(run_path)
        else:
            if (self.pred.config['source'] == 'run'
                    and self.pred.config['run_threshold'] > 0):
                official_run = self.pred.dataset.run('dict')
            else:
                official_run = {}
            run = {}  # stays empty on this path; scores are streamed to run_path
            ranker = ctxt['ranker']().to(self.device)
            # Scores arrive grouped by query: buffer one query's documents and
            # flush them (with the run threshold applied) whenever qid changes.
            this_qid = None
            these_docs = {}
            with util.finialized_file(run_path, 'wt') as f:
                for qid, did, score in self.pred.iter_scores(
                        ranker, self.datasource, self.device):
                    if qid != this_qid:
                        if this_qid is not None:
                            these_docs = self._apply_threshold(
                                these_docs, official_run.get(this_qid, {}))
                            trec.write_run_dict(f, {this_qid: these_docs})
                        this_qid = qid
                        these_docs = {}
                    these_docs[did] = score
                if this_qid is not None:
                    these_docs = self._apply_threshold(
                        these_docs, official_run.get(this_qid, {}))
                    trec.write_run_dict(f, {this_qid: these_docs})
            cached = False

        result = {
            'epoch': epoch,
            'run': run,
            'run_path': run_path,
            'base_path': base_path,
            'cached': cached
        }

        result['metrics'] = {
            m: None
            for m in self.pred.config['measures'].split(',') if m
        }
        result['metrics_by_query'] = {m: None for m in result['metrics']}

        missing_metrics = self.load_metrics(result)

        if missing_metrics:
            measures = set(missing_metrics)
            result['cached'] = False
            qrels = self.pred.dataset.qrels()
            calculated_metrics = onir.metrics.calc(qrels, run_path, measures)
            result['metrics_by_query'].update(calculated_metrics)
            result['metrics'].update(onir.metrics.mean(calculated_metrics))
            self.write_missing_metrics(result, missing_metrics)

        try:
            if ctxt['ranker']().config.get('add_runscore'):
                result['metrics']['runscore_alpha'] = torch.sigmoid(
                    ctxt['ranker']().runscore_alpha).item()
                rs_alpha_f = os.path.join(ctxt['base_path'],
                                          'runscore_alpha.txt')
                with open(rs_alpha_f, 'at') as f:
                    # append this epoch's runscore alpha to the open handle
                    plaintext.write_tsv(f, [
                        (str(epoch), str(result['metrics']['runscore_alpha']))
                    ])
        except FileNotFoundError:
            pass  # model may no longer exist, ignore

        return result
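
The per-query buffering above is the usual "flush on key change" idiom. Here is a standalone sketch of the same grouping using itertools.groupby, assuming (as the loop above does) that scores arrive ordered by query ID; the function name is mine.

    import itertools

    def iter_query_groups(scored):
        # scored yields (qid, did, score) tuples already ordered by qid.
        for qid, rows in itertools.groupby(scored, key=lambda row: row[0]):
            yield qid, {did: score for _, did, score in rows}

    # list(iter_query_groups([('q1', 'd1', 0.9), ('q1', 'd2', 0.5),
    #                         ('q2', 'd3', 0.7)]))
    # -> [('q1', {'d1': 0.9, 'd2': 0.5}), ('q2', {'d3': 0.7})]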
Example #3
 def batch_query(self, queries, model, topk, destf=None, quiet=False):
     THREADS = onir.util.safe_thread_count()
     query_file_splits = 1000
     if hasattr(queries, '__len__'):
         if len(queries) < THREADS:
             THREADS = len(queries)
             query_file_splits = 1
         elif len(queries) < THREADS * 10:
             query_file_splits = ((len(queries)+1) // THREADS)
         elif len(queries) < THREADS * 100:
             query_file_splits = ((len(queries)+1) // (THREADS * 10))
         else:
             query_file_splits = ((len(queries)+1) // (THREADS * 100))
     with tempfile.TemporaryDirectory() as topic_d, tempfile.TemporaryDirectory() as run_d:
         run_f = os.path.join(run_d, 'run')
         topic_files = []
         current_file = None
         total_topics = 0
         for i, (qid, text) in enumerate(queries):
             topic_file = '{}/{}.queries'.format(topic_d, i // query_file_splits)
             if current_file is None or current_file.name != topic_file:
                 if current_file is not None:
                     topic_files.append(current_file.name)
                     current_file.close()
                 current_file = open(topic_file, 'wt')
             current_file.write(f'{qid}\t{text}\n')
             total_topics += 1
         if current_file is not None:
             topic_files.append(current_file.name)
             current_file.close()
         args = J.A_SearchArgs()
         parser = J.M_CmdLineParser(args)
         arg_args = [
             '-index', self._path,
             '-topics', *topic_files,
             '-output', run_f,
             '-topicreader', 'TsvString',
             '-threads', str(THREADS),
             '-hits', str(topk),
             '-language', self._settings['lang'],
         ]
         # Model strings follow '<name>_<key>-<value>_...', e.g.
         # 'bm25_k1-0.9_b-0.4' or 'bm25_rm3'; flag-style parameters carry no value.
         if model.startswith('bm25'):
             arg_args.append('-bm25')
             model_args = [arg.split('-', 1) for arg in model.split('_')[1:]]
             for arg in model_args:
                 if len(arg) == 1:
                     k, v = arg[0], None
                 elif len(arg) == 2:
                     k, v = arg
                 if k == 'k1':
                     arg_args.append('-bm25.k1')
                     arg_args.append(v)
                 elif k == 'b':
                     arg_args.append('-bm25.b')
                     arg_args.append(v)
                 elif k == 'rm3':
                     arg_args.append('-rm3')
                 elif k == 'rm3.fbTerms':
                     arg_args.append('-rm3.fbTerms')
                     arg_args.append(v)
                 elif k == 'rm3.fbDocs':
                     arg_args.append('-rm3.fbDocs')
                     arg_args.append(v)
                 else:
                     raise ValueError(f'unknown bm25 parameter {arg}')
         elif model.startswith('ql'):
             arg_args.append('-qld')
             model_args = [arg.split('-', 1) for arg in model.split('_')[1:]]
             for arg in model_args:
                 if len(arg) == 1:
                     k, v = arg[0], None
                 elif len(arg) == 2:
                     k, v = arg
                 if k == 'mu':
                     arg_args.append('-qld.mu')
                     arg_args.append(v)
                 else:
                     raise ValueError(f'unknown ql parameter {arg}')
         elif model.startswith('sdm'):
             arg_args.append('-sdm')
             arg_args.append('-qld')
             model_args = [arg.split('-', 1) for arg in model.split('_')[1:]]
             for arg in model_args:
                 if len(arg) == 1:
                     k, v = arg[0], None
                 elif len(arg) == 2:
                     k, v = arg
                 if k == 'mu':
                     arg_args.append('-qld.mu')
                     arg_args.append(v)
                 elif k == 'tw':
                     arg_args.append('-sdm.tw')
                     arg_args.append(v)
                 elif k == 'ow':
                     arg_args.append('-sdm.ow')
                     arg_args.append(v)
                 elif k == 'uw':
                     arg_args.append('-sdm.uw')
                     arg_args.append(v)
                 else:
                     raise ValueError(f'unknown sdm parameter {arg}')
         else:
             raise ValueError(f'unknown model {model}')
         parser.parseArgument(*arg_args)
         with contextlib.ExitStack() as stack:
             stack.enter_context(J.listen_java_log(_surpress_log('io.anserini.search.SearchCollection')))
             if not quiet:
                 pbar = stack.enter_context(logger.pbar_raw(desc=f'batch_query ({model})', total=total_topics))
                 stack.enter_context(J.listen_java_log(pbar_bq_listener(pbar)))
             searcher = J.A_SearchCollection(args)
             searcher.runTopics()
             searcher.close()
         if destf:
             shutil.copy(run_f, destf)
         else:
             return trec.read_run_dict(run_f)
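
All three branches above parse the same '<name>_<key>-<value>' convention before mapping keys onto Anserini flags. Below is a standalone sketch of just the parsing step; the function name is mine, and the per-model flag mapping is omitted.

    def parse_model_string(model):
        # 'bm25_k1-0.9_b-0.4' -> ('bm25', {'k1': '0.9', 'b': '0.4'})
        # 'bm25_rm3'          -> ('bm25', {'rm3': None})
        name, *params = model.split('_')
        parsed = {}
        for param in params:
            key, _, value = param.partition('-')
            parsed[key] = value or None
        return name, parsed

    # parse_model_string('sdm_mu-1000_tw-0.8')
    # -> ('sdm', {'mu': '1000', 'tw': '0.8'})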
Example #4
    def init(self, force=False):
        idxs = [self.index_stem, self.doc_store]
        self._init_indices_parallel(idxs, self._init_iter_collection(), force)
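        # MINI_DEV (used throughout below) is assumed to be a module-level set
        # of train query IDs held out as a small 'minidev' development split.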

        base_path = util.path_dataset(self)

        needs_queries = []
        if force or not os.path.exists(
                os.path.join(base_path, 'train.queries.tsv')):
            needs_queries.append(lambda it: plaintext.write_tsv(
                os.path.join(base_path, 'train.queries.tsv'),
                ((qid, txt) for file, qid, txt in it
                 if file == 'queries.train.tsv' and qid not in MINI_DEV)))
        if force or not os.path.exists(
                os.path.join(base_path, 'minidev.queries.tsv')):
            needs_queries.append(lambda it: plaintext.write_tsv(
                os.path.join(base_path, 'minidev.queries.tsv'),
                ((qid, txt) for file, qid, txt in it
                 if file == 'queries.train.tsv' and qid in MINI_DEV)))
        if force or not os.path.exists(
                os.path.join(base_path, 'dev.queries.tsv')):
            needs_queries.append(lambda it: plaintext.write_tsv(
                os.path.join(base_path, 'dev.queries.tsv'),
                ((qid, txt) for file, qid, txt in it
                 if file == 'queries.dev.tsv')))
        if force or not os.path.exists(
                os.path.join(base_path, 'eval.queries.tsv')):
            needs_queries.append(lambda it: plaintext.write_tsv(
                os.path.join(base_path, 'eval.queries.tsv'),
                ((qid, txt) for file, qid, txt in it
                 if file == 'queries.eval.tsv')))

        if needs_queries and self._confirm_dua():
            with util.download_tmp(_SOURCES['queries']) as f, \
                 tarfile.open(fileobj=f) as tarf, \
                 contextlib.ExitStack() as ctxt:

                def _extr_subf(subf):
                    for qid, txt in plaintext.read_tsv(
                            io.TextIOWrapper(tarf.extractfile(subf))):
                        yield subf, qid, txt

                query_iter = [
                    _extr_subf('queries.train.tsv'),
                    _extr_subf('queries.dev.tsv'),
                    _extr_subf('queries.eval.tsv')
                ]
                query_iter = tqdm(itertools.chain(*query_iter), desc='queries')
                query_iters = util.blocking_tee(query_iter, len(needs_queries))
                for fn, it in zip(needs_queries, query_iters):
                    ctxt.enter_context(
                        util.CtxtThread(functools.partial(fn, it)))

        file = os.path.join(base_path, 'train.qrels')
        if (force or not os.path.exists(file)) and self._confirm_dua():
            stream = util.download_stream(_SOURCES['train-qrels'], 'utf8')
            with util.finialized_file(file, 'wt') as out:
                for qid, _, did, score in plaintext.read_tsv(stream):
                    if qid not in MINI_DEV:
                        trec.write_qrels(out, [(qid, did, score)])

        file = os.path.join(base_path, 'minidev.qrels')
        if (force or not os.path.exists(file)) and self._confirm_dua():
            stream = util.download_stream(_SOURCES['train-qrels'], 'utf8')
            with util.finialized_file(file, 'wt') as out:
                for qid, _, did, score in plaintext.read_tsv(stream):
                    if qid in MINI_DEV:
                        trec.write_qrels(out, [(qid, did, score)])

        file = os.path.join(base_path, 'dev.qrels')
        if (force or not os.path.exists(file)) and self._confirm_dua():
            stream = util.download_stream(_SOURCES['dev-qrels'], 'utf8')
            with util.finialized_file(file, 'wt') as out:
                for qid, _, did, score in plaintext.read_tsv(stream):
                    trec.write_qrels(out, [(qid, did, score)])

        file = os.path.join(base_path, 'train.mspairs.gz')
        if not os.path.exists(file) and os.path.exists(
                os.path.join(base_path, 'qidpidtriples.train.full')):
            # legacy
            os.rename(os.path.join(base_path, 'qidpidtriples.train.full'),
                      file)
        if (force or not os.path.exists(file)) and self._confirm_dua():
            util.download(_SOURCES['qidpidtriples.train.full'], file)

        if not self.config['init_skip_msrun']:
            for file_name, subf in [('dev.msrun', 'top1000.dev'),
                                    ('eval.msrun', 'top1000.eval'),
                                    ('train.msrun', 'top1000.train.txt')]:
                file = os.path.join(base_path, file_name)
                if (force or not os.path.exists(file)) and self._confirm_dua():
                    run = {}
                    with util.download_tmp(_SOURCES[file_name]) as f, \
                         tarfile.open(fileobj=f) as tarf:
                        for qid, did, _, _ in tqdm(
                                plaintext.read_tsv(
                                    io.TextIOWrapper(tarf.extractfile(subf)))):
                            if qid not in run:
                                run[qid] = {}
                            run[qid][did] = 0.
                    if file_name == 'train.msrun':
                        minidev = {
                            qid: dids
                            for qid, dids in run.items() if qid in MINI_DEV
                        }
                        with self.logger.duration('writing minidev.msrun'):
                            trec.write_run_dict(
                                os.path.join(base_path, 'minidev.msrun'),
                                minidev)
                        run = {
                            qid: dids
                            for qid, dids in run.items() if qid not in MINI_DEV
                        }
                    with self.logger.duration(f'writing {file_name}'):
                        trec.write_run_dict(file, run)

        query_path = os.path.join(base_path, 'trec2019.queries.tsv')
        if (force or not os.path.exists(query_path)) and self._confirm_dua():
            stream = util.download_stream(_SOURCES['trec2019.queries'], 'utf8')
            plaintext.write_tsv(query_path, plaintext.read_tsv(stream))
        msrun_path = os.path.join(base_path, 'trec2019.msrun')
        if (force or not os.path.exists(msrun_path)) and self._confirm_dua():
            run = {}
            with util.download_stream(_SOURCES['trec2019.msrun'],
                                      'utf8') as stream:
                for qid, did, _, _ in plaintext.read_tsv(stream):
                    if qid not in run:
                        run[qid] = {}
                    run[qid][did] = 0.
            with util.finialized_file(msrun_path, 'wt') as f:
                trec.write_run_dict(f, run)

        qrels_path = os.path.join(base_path, 'trec2019.qrels')
        if not os.path.exists(qrels_path) and self._confirm_dua():
            util.download(_SOURCES['trec2019.qrels'], qrels_path)
        qrels_path = os.path.join(base_path, 'judgedtrec2019.qrels')
        if not os.path.exists(qrels_path):
            os.symlink('trec2019.qrels', qrels_path)
        query_path = os.path.join(base_path, 'judgedtrec2019.queries.tsv')
        # the lambda evaluates qrels_path lazily, i.e. the judgedtrec2019.qrels
        # symlink assigned above (same content as trec2019.qrels)
        judged_qids = util.Lazy(
            lambda: trec.read_qrels_dict(qrels_path).keys())
        if (force or not os.path.exists(query_path)):
            with util.finialized_file(query_path, 'wt') as f:
                for qid, qtext in plaintext.read_tsv(
                        os.path.join(base_path, 'trec2019.queries.tsv')):
                    if qid in judged_qids():
                        plaintext.write_tsv(f, [(qid, qtext)])
        msrun_path = os.path.join(base_path, 'judgedtrec2019.msrun')
        if (force or not os.path.exists(msrun_path)) and self._confirm_dua():
            with util.finialized_file(msrun_path, 'wt') as f:
                for qid, dids in trec.read_run_dict(
                        os.path.join(base_path, 'trec2019.msrun')).items():
                    if qid in judged_qids():
                        trec.write_run_dict(f, {qid: dids})

        # A subset of dev that only contains queries that have relevance judgments
        judgeddev_path = os.path.join(base_path, 'judgeddev')
        judged_qids = util.Lazy(lambda: trec.read_qrels_dict(
            os.path.join(base_path, 'dev.qrels')).keys())
        if not os.path.exists(f'{judgeddev_path}.qrels'):
            os.symlink('dev.qrels', f'{judgeddev_path}.qrels')
        if not os.path.exists(f'{judgeddev_path}.queries.tsv'):
            with util.finialized_file(f'{judgeddev_path}.queries.tsv',
                                      'wt') as f:
                for qid, qtext in plaintext.read_tsv(
                        os.path.join(base_path, 'dev.queries.tsv')):
                    if qid in judged_qids():
                        plaintext.write_tsv(f, [(qid, qtext)])
        # dev.msrun only exists when the msrun initialization above ran
        if not self.config['init_skip_msrun']:
            if not os.path.exists(f'{judgeddev_path}.msrun'):
                with util.finialized_file(f'{judgeddev_path}.msrun',
                                          'wt') as f:
                    for qid, dids in trec.read_run_dict(
                            os.path.join(base_path, 'dev.msrun')).items():
                        if qid in judged_qids():
                            trec.write_run_dict(f, {qid: dids})

        if not self.config['init_skip_train10']:
            file = os.path.join(base_path, 'train10.queries.tsv')
            if not os.path.exists(file):
                with util.finialized_file(file, 'wt') as fout:
                    for qid, qtext in self.logger.pbar(
                            plaintext.read_tsv(
                                os.path.join(base_path, 'train.queries.tsv')),
                            desc='filtering queries for train10'):
                        if int(qid) % 10 == 0:
                            plaintext.write_tsv(fout, [(qid, qtext)])

            file = os.path.join(base_path, 'train10.qrels')
            if not os.path.exists(file):
                with util.finialized_file(file, 'wt') as fout, open(
                        os.path.join(base_path, 'train.qrels'), 'rt') as fin:
                    for line in self.logger.pbar(
                            fin, desc='filtering qrels for train10'):
                        qid = line.split()[0]
                        if int(qid) % 10 == 0:
                            fout.write(line)

            if not self.config['init_skip_msrun']:
                file = os.path.join(base_path, 'train10.msrun')
                if not os.path.exists(file):
                    with util.finialized_file(file, 'wt') as fout, open(
                            os.path.join(base_path, 'train.msrun'),
                            'rt') as fin:
                        for line in self.logger.pbar(
                                fin, desc='filtering msrun for train10'):
                            qid = line.split()[0]
                            if int(qid) % 10 == 0:
                                fout.write(line)

            file = os.path.join(base_path, 'train10.mspairs.gz')
            if not os.path.exists(file):
                with gzip.open(file, 'wt') as fout, gzip.open(
                        os.path.join(base_path, 'train.mspairs.gz'),
                        'rt') as fin:
                    for qid, did1, did2 in self.logger.pbar(
                            plaintext.read_tsv(fin),
                            desc='filtering mspairs for train10'):
                        if int(qid) % 10 == 0:
                            plaintext.write_tsv(fout, [(qid, did1, did2)])

        if not self.config['init_skip_train_med']:
            med_qids = util.Lazy(
                lambda: {
                    qid.strip()
                    for qid in util.download_stream(
                        'https://raw.githubusercontent.com/Georgetown-IR-Lab/covid-neural-ir/master/med-msmarco-train.txt',
                        'utf8',
                        expected_md5="dc5199de7d4a872c361f89f08b1163ef")
                })
            file = os.path.join(base_path, 'train_med.queries.tsv')
            if not os.path.exists(file):
                with util.finialized_file(file, 'wt') as fout:
                    for qid, qtext in self.logger.pbar(
                            plaintext.read_tsv(
                                os.path.join(base_path, 'train.queries.tsv')),
                            desc='filtering queries for train_med'):
                        if qid in med_qids():
                            plaintext.write_tsv(fout, [(qid, qtext)])

            file = os.path.join(base_path, 'train_med.qrels')
            if not os.path.exists(file):
                with util.finialized_file(file, 'wt') as fout, open(
                        os.path.join(base_path, 'train.qrels'), 'rt') as fin:
                    for line in self.logger.pbar(
                            fin, desc='filtering qrels for train_med'):
                        qid = line.split()[0]
                        if qid in med_qids():
                            fout.write(line)

            if not self.config['init_skip_msrun']:
                file = os.path.join(base_path, 'train_med.msrun')
                if not os.path.exists(file):
                    with util.finialized_file(file, 'wt') as fout, open(
                            os.path.join(base_path, 'train.msrun'),
                            'rt') as fin:
                        for line in self.logger.pbar(
                                fin, desc='filtering msrun for train_med'):
                            qid = line.split()[0]
                            if qid in med_qids():
                                fout.write(line)

            file = os.path.join(base_path, 'train_med.mspairs.gz')
            if not os.path.exists(file):
                with gzip.open(file, 'wt') as fout, gzip.open(
                        os.path.join(base_path, 'train.mspairs.gz'),
                        'rt') as fin:
                    for qid, did1, did2 in self.logger.pbar(
                            plaintext.read_tsv(fin),
                            desc='filtering mspairs for train_med'):
                        if qid in med_qids():
                            plaintext.write_tsv(fout, [(qid, did1, did2)])
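
Both train10 and train_med reuse the same pattern: build a predicate over query IDs, then filter each train artifact with it. Below is a standalone sketch of the deterministic 10% predicate and a generic filter; the names are mine, and the filter mirrors the loops above minus progress bars.

    def in_train10(qid):
        # Deterministic ~10% subsample: keeping IDs divisible by 10 means
        # every run of init() selects exactly the same subset.
        return int(qid) % 10 == 0

    def filter_by_qid(src, dst, keep):
        # Copy lines whose first whitespace-separated field (the qid)
        # satisfies `keep`.
        with open(src, 'rt') as fin, open(dst, 'wt') as fout:
            for line in fin:
                if keep(line.split()[0]):
                    fout.write(line)

    # filter_by_qid('train.qrels', 'train10.qrels', in_train10)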