def setUp(self):
        curdir = os.getcwd()
        if curdir.endswith('sparse'):
            self.pyserini_root = '../..'
        else:
            self.pyserini_root = '.'

        if os.path.isdir('ibm_test'):
            rmtree('ibm_test')
        os.mkdir('ibm_test')
        # Download the prebuilt index.
        SimpleSearcher.from_prebuilt_index('msmarco-passage-ltr')
        inp = 'run.msmarco-passage.bm25tuned.trec'
        os.system(
            f'python -m pyserini.search --topics msmarco-passage-dev-subset  --index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3/ --output ibm_test/{inp} --bm25 --output-format trec --hits 1000 --k1 0.82 --b 0.68'
        )
        #ibm model
        ibm_model_url = 'https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-models/ibm_model_1_bert_tok_20211117.tar.gz'
        ibm_model_tar_name = 'ibm_model_1_bert_tok_20211117.tar.gz'
        os.system(f'wget {ibm_model_url} -P ibm_test/')
        os.system(f'tar -xzvf ibm_test/{ibm_model_tar_name} -C ibm_test')
        # Process the queries.
        os.system(
            'python scripts/ltr_msmarco/convert_queries.py --input tools/topics-and-qrels/topics.msmarco-passage.dev-subset.txt --output ibm_test/queries.dev.small.json'
        )
Example #2
 def test_reranking(self):
     if os.path.isdir('ltr_test'):
         rmtree('ltr_test')
     os.mkdir('ltr_test')
     inp = 'run.msmarco-passage.bm25tuned.txt'
     outp = 'run.ltr.msmarco-passage.test.tsv'
     #Download candidate
     os.system('wget https://www.dropbox.com/s/bjyzf65uns2is61/run.msmarco-passage.bm25tuned.txt -P ltr_test')
     #Download prebuilt index
     SimpleSearcher.from_prebuilt_index('msmarco-passage-ltr')
     #Pre-trained ltr model
     model_url = 'https://www.dropbox.com/s/ffl2bfw4cd5ngyz/msmarco-passage-ltr-mrr-v1.tar.gz'
     model_tar_name = 'msmarco-passage-ltr-mrr-v1.tar.gz'
     os.system(f'wget {model_url} -P ltr_test/')
     os.system(f'tar -xzvf ltr_test/{model_tar_name} -C ltr_test')
     #ibm model
     ibm_model_url = 'https://www.dropbox.com/s/vlrfcz3vmr4nt0q/ibm_model.tar.gz'
     ibm_model_tar_name = 'ibm_model.tar.gz'
     os.system(f'wget {ibm_model_url} -P ltr_test/')
     os.system(f'tar -xzvf ltr_test/{ibm_model_tar_name} -C ltr_test')
     # Process the queries.
     os.system('python scripts/ltr_msmarco-passage/convert_queries.py --input tools/topics-and-qrels/topics.msmarco-passage.dev-subset.txt --output ltr_test/queries.dev.small.json')
     os.system(f'python -m pyserini.ltr.search_msmarco_passage --input ltr_test/{inp} --input-format tsv --model ltr_test/msmarco-passage-ltr-mrr-v1 --index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3 --ibm-model ltr_test/ibm_model/ --queries ltr_test --output ltr_test/{outp}')
     result = subprocess.check_output(f'python tools/scripts/msmarco/msmarco_passage_eval.py tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt ltr_test/{outp}', shell=True).decode(sys.stdout.encoding)
     a,b = result.find('#####################\nMRR @10:'), result.find('\nQueriesRanked: 6980\n#####################\n')
     mrr = result[a+31:b]
     self.assertAlmostEqual(float(mrr),0.24709612498294367, delta=0.000001)
     rmtree('ltr_test')
Example #3
class MsMarcoDemo(cmd.Cmd):
    searcher = SimpleSearcher.from_prebuilt_index('msmarco-passage')
    k = 10
    prompt = '>>> '

    # https://stackoverflow.com/questions/35213134/command-prefixes-in-python-cli-using-cmd-in-pythons-standard-library
    def precmd(self, line):
        if line[0] == '/':
            line = line[1:]
        return line

    def do_help(self, arg):
        print(f'/help    : returns this message')
        print(f'/k [NUM] : sets k (number of hits to return) to [NUM]')

    def do_k(self, arg):
        print(f'setting k = {int(arg)}')
        self.k = int(arg)

    def do_EOF(self, line):
        return True

    def default(self, q):
        hits = self.searcher.search(q, self.k)

        for i in range(0, len(hits)):
            jsondoc = json.loads(hits[i].raw)
            print(f'{i + 1:2} {hits[i].score:.5f} {jsondoc["contents"]}')
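The snippet above only defines the cmd.Cmd subclass; the entry point is not part of the example. A minimal sketch of how the demo would typically be launched (the __main__ guard below is an assumption, not shown in the original):

if __name__ == '__main__':
    # cmdloop() starts the interactive prompt and dispatches /k, /help, and
    # free-text queries to the handlers defined above.
    MsMarcoDemo().cmdloop()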
Example #4
def main(args):
    if args.cache and not os.path.exists(args.cache):
        os.mkdir(args.cache)

    # Load queries:
    queries = load_queries(args.queries)
    # Load base run to rerank:
    base_run = TrecRun(args.input)

    # SimpleSearcher to fetch document texts.
    searcher = SimpleSearcher.from_prebuilt_index('msmarco-doc')

    output = []

    if args.bm25:
        reranker = 'bm25'
    elif args.ance:
        reranker = 'ance'
    elif not args.identity:
        sys.exit('Unknown reranking method!')

    cnt = 1
    for row in queries:
        qid = int(row[0])
        query = row[1]
        print(f'{cnt} {qid} {query}')
        qid_results = base_run.get_docs_by_topic(qid)

        # Don't actually do reranking, just pass along the base run:
        if args.identity:
            rank = 1
            for docid in qid_results['docid'].tolist():
                output.append([qid, docid, rank])
                rank = rank + 1
            cnt = cnt + 1
            continue

        # Gather results for reranking:
        results_to_rerank = []
        for index, result in qid_results.iterrows():
            # lstrip/rstrip strip *character sets*, which can also remove
            # leading document text; strip the literal tags instead (Python 3.9+).
            raw_doc = searcher.doc(result['docid']).raw()
            raw_doc = raw_doc.removeprefix('<TEXT>').removesuffix('</TEXT>')
            results_to_rerank.append({
                'docid': result['docid'],
                'rank': result['rank'],
                'score': result['score'],
                'text': raw_doc
            })

        # Perform the actual reranking:
        output.extend(
            rerank(args.cache, qid, query, results_to_rerank, reranker))
        cnt = cnt + 1

    # Write the output run file:
    with open(args.output, 'w') as writer:
        for r in output:
            writer.write(f'{r[0]}\t{r[1]}\t{r[2]}\n')
Example #5
    def test_reranking(self):
        if os.path.isdir('ltr_test'):
            rmtree('ltr_test')
        os.mkdir('ltr_test')
        inp = 'run.msmarco-pass-doc.bm25.txt'
        outp = 'run.ltr.msmarco-pass-doc.test.trec'
        outp_tsv = 'run.ltr.msmarco-pass-doc.test.tsv'
        #Download prebuilt index
        #retrieve candidate
        SimpleSearcher.from_prebuilt_index('msmarco-doc-per-passage-ltr')
        os.system(
            f'python -m pyserini.search --topics msmarco-doc-dev  --index ~/.cache/pyserini/indexes/index-msmarco-doc-per-passage-ltr-20211031-33e4151.bd60e89041b4ebbabc4bf0cfac608a87/ --output ltr_test/{inp} --bm25 --output-format trec --hits 10000'
        )
        #Pre-trained ltr model
        model_url = 'https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-models/model-ltr-msmarco-passage-mrr-v1.tar.gz'
        model_tar_name = 'model-ltr-msmarco-passage-mrr-v1.tar.gz'
        os.system(f'wget {model_url} -P ltr_test/')
        os.system(f'tar -xzvf ltr_test/{model_tar_name} -C ltr_test')
        #ibm model
        ibm_model_url = 'https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-models/model-ltr-ibm.tar.gz'
        ibm_model_tar_name = 'model-ltr-ibm.tar.gz'
        os.system(f'wget {ibm_model_url} -P ltr_test/')
        os.system(f'tar -xzvf ltr_test/{ibm_model_tar_name} -C ltr_test')
        # Process the queries.
        os.system(
            'python scripts/ltr_msmarco/convert_queries.py --input tools/topics-and-qrels/topics.msmarco-doc.dev.txt --output ltr_test/queries.dev.small.json'
        )
        os.system(
            f'python scripts/ltr_msmarco/ltr_inference.py  --input ltr_test/{inp} --input-format trec --data document --model ltr_test/msmarco-passage-ltr-mrr-v1/ --index ~/.cache/pyserini/indexes/index-msmarco-doc-per-passage-ltr-20211031-33e4151.bd60e89041b4ebbabc4bf0cfac608a87 --ibm-model ltr_test/ibm_model/ --queries ltr_test --output ltr_test/{outp}'
        )
        # Convert the TREC run to tsv using MaxP aggregation.
        os.system(
            f'python scripts/ltr_msmarco/generate_document_score_withmaxP.py --input ltr_test/{outp} --output ltr_test/{outp_tsv}'
        )

        result = subprocess.check_output(
            f'python tools/scripts/msmarco/msmarco_doc_eval.py --judgments tools/topics-and-qrels/qrels.msmarco-doc.dev.txt --run ltr_test/{outp_tsv}',
            shell=True).decode(sys.stdout.encoding)
        a, b = result.find('#####################\nMRR @100:'), result.find(
            '\nQueriesRanked: 5193\n#####################\n')
        mrr = result[a + 32:b]
        # See https://github.com/castorini/pyserini/issues/951
        self.assertAlmostEqual(float(mrr), 0.3091, delta=0.0001)
        rmtree('ltr_test')
Example #6
 def __init__(self, index_dir: str, query_encoder: Union[QueryEncoder, str], prebuilt_index_name: Optional[str] = None):
     requires_backends(self, "faiss")
     if isinstance(query_encoder, QueryEncoder):
         self.query_encoder = query_encoder
     else:
         self.query_encoder = self._init_encoder_from_str(query_encoder)
     self.index, self.docids = self.load_index(index_dir)
     self.dimension = self.index.d
     self.num_docs = self.index.ntotal
     
     assert self.docids is None or self.num_docs == len(self.docids)
     if prebuilt_index_name:
         sparse_index = get_sparse_index(prebuilt_index_name)
         self.ssearcher = SimpleSearcher.from_prebuilt_index(sparse_index)
Example #7
 def test_reranking(self):
     if os.path.isdir('ltr_test'):
         rmtree('ltr_test')
     os.mkdir('ltr_test')
     inp = 'run.msmarco-passage.bm25tuned.txt'
     outp = 'run.ltr.msmarco-passage.test.tsv'
     #Download prebuilt index
     SimpleSearcher.from_prebuilt_index('msmarco-passage-ltr')
     os.system(
         f'python -m pyserini.search --topics msmarco-passage-dev-subset  --index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3/ --output ltr_test/{inp} --bm25 --output-format msmarco --hits 1000 --k1 0.82 --b 0.68'
     )
     #Pre-trained ltr model
     model_url = 'https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-models/model-ltr-msmarco-passage-mrr-v1.tar.gz'
     model_tar_name = 'model-ltr-msmarco-passage-mrr-v1.tar.gz'
     os.system(f'wget {model_url} -P ltr_test/')
     os.system(f'tar -xzvf ltr_test/{model_tar_name} -C ltr_test')
     #ibm model
     ibm_model_url = 'https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-models/model-ltr-ibm.tar.gz'
     ibm_model_tar_name = 'model-ltr-ibm.tar.gz'
     os.system(f'wget {ibm_model_url} -P ltr_test/')
     os.system(f'tar -xzvf ltr_test/{ibm_model_tar_name} -C ltr_test')
     # Process the queries.
     os.system(
         'python scripts/ltr_msmarco/convert_queries.py --input tools/topics-and-qrels/topics.msmarco-passage.dev-subset.txt --output ltr_test/queries.dev.small.json'
     )
     os.system(
         f'python scripts/ltr_msmarco/ltr_inference.py  --input ltr_test/{inp} --input-format tsv --model ltr_test/msmarco-passage-ltr-mrr-v1 --data passage --index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3 --ibm-model ltr_test/ibm_model/ --queries ltr_test --output-format tsv --output ltr_test/{outp}'
     )
     result = subprocess.check_output(
         f'python tools/scripts/msmarco/msmarco_passage_eval.py tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt ltr_test/{outp}',
         shell=True).decode(sys.stdout.encoding)
     a, b = result.find('#####################\nMRR @10:'), result.find(
         '\nQueriesRanked: 6980\n#####################\n')
     mrr = result[a + 31:b]
     # See https://github.com/castorini/pyserini/issues/951
     self.assertAlmostEqual(float(mrr), 0.2472, delta=0.0001)
     rmtree('ltr_test')
Example #8
def build_searcher(settings: SearcherSettings) -> SimpleSearcher:
    if path.isdir(settings.index_path):
        searcher = SimpleSearcher(settings.index_path)
    else:
        searcher = SimpleSearcher.from_prebuilt_index(settings.index_path)
    searcher.set_bm25(float(settings.k1), float(settings.b))
    logging.info(
        "Initializing BM25, setting k1={} and b={}".format(settings.k1, settings.b)
    )
    if settings.rm3:
        searcher.set_rm3(
            settings.fb_terms, settings.fb_docs, settings.original_query_weight
        )
        logging.info(
            "Initializing RM3, setting fbTerms={}, fbDocs={} and originalQueryWeight={}".format(
                settings.fb_terms, settings.fb_docs, settings.original_query_weight
            )
        )
    return searcher
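build_searcher only reads a handful of fields off its settings argument; SearcherSettings itself is not shown in this example. A hedged sketch of a compatible settings object, assuming a plain dataclass with the fields the function actually uses (names and defaults here are illustrative, not the real class):

from dataclasses import dataclass

@dataclass
class SearcherSettings:  # hypothetical stand-in for the real settings class
    index_path: str
    k1: float = 0.9
    b: float = 0.4
    rm3: bool = False
    fb_terms: int = 10
    fb_docs: int = 10
    original_query_weight: float = 0.5

# Works with either a local index directory or a prebuilt index name.
searcher = build_searcher(SearcherSettings(index_path='msmarco-passage'))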
Example #9
import sys

# Only for debugging purposes, using a Pyserini local installation.
# sys.path.insert(0, '../pyserini/')

from fastapi import FastAPI
from pyserini.search import SimpleSearcher
from typing import Optional

searcher = SimpleSearcher.from_prebuilt_index('msmarco-doc-expanded-per-doc')
searcher.set_bm25(4.68, 0.87)
app = FastAPI()


@app.get("/search/")
def search(q: str, k: Optional[int] = 1000):
    hits = searcher.search(q, k=k)

    results = []
    for i in range(0, len(hits)):
        results.append({'docid': hits[i].docid, 'score': hits[i].score})

    return {'results': results}
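The module above only creates the FastAPI app; how it is served is left to the deployment. A minimal sketch, assuming uvicorn is installed and this file is the module being run (host and port values are arbitrary):

import uvicorn

if __name__ == '__main__':
    # Serve the app; a query then looks like
    #   GET http://localhost:8000/search/?q=what+is+a+lobster+roll&k=5
    uvicorn.run(app, host='0.0.0.0', port=8000)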
Example #10
    parser.add_argument('--reader-model', type=str, required=False, help="Reader model name or path")
    parser.add_argument('--reader-device', type=str, required=False, default='cuda:0', help="Device to run inference on")

    args = parser.parse_args()

    # check arguments
    arg_check(args, parser)

    print("Init QA models")
    if args.type == 'openbook':
        if args.qa_reader == 'dpr':
            reader = DprReader(args.reader_model, device=args.reader_device)
            if args.retriever_model:
                retriever = SimpleDenseSearcher(args.retriever_index, DprQueryEncoder(args.retriever_model))
            else:
                retriever = SimpleSearcher.from_prebuilt_index(args.retriever_corpus)
            corpus = SimpleSearcher.from_prebuilt_index(args.retriever_corpus)
            obqa = OpenBookQA(reader, retriever, corpus)
            # run a warm up question
            obqa.predict('what is lobster roll')
            while True:
                question = input('Enter a question: ')
                answer = obqa.predict(question)
                answer_text = answer["answer"]
                answer_context = answer["context"]["text"]
                print(f"Answer:\t {answer_text}")
                print(f"Context:\t {answer_context}")
        elif args.qa_reader == 'fid':
            reader = FidReader(model_name=args.reader_model, device=args.reader_device)
            if args.retriever_model:
                # retriever = SimpleDenseSearcher(args.retriever_index, DkrrDprQueryEncoder(args.retriever_model))
Example #11
 def test_custom_cache(self):
     os.environ['PYSERINI_CACHE'] = 'temp_dir'
     SimpleSearcher.from_prebuilt_index('cacm')
     self.assertTrue(os.path.exists('temp_dir/indexes'))
Example #12
 def test_default_cache(self):
     SimpleSearcher.from_prebuilt_index('cacm')
     self.assertTrue(os.path.exists(os.path.expanduser('~/.cache/pyserini/indexes')))
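The two tests above show that from_prebuilt_index honours the PYSERINI_CACHE environment variable and otherwise falls back to ~/.cache/pyserini. A short sketch of the same mechanism outside a test, assuming the variable is set before the first download is triggered (the path is illustrative):

import os

os.environ['PYSERINI_CACHE'] = '/data/pyserini-cache'  # hypothetical location

from pyserini.search import SimpleSearcher

# The prebuilt index is downloaded under /data/pyserini-cache/indexes/
# instead of the default ~/.cache/pyserini/indexes/.
searcher = SimpleSearcher.from_prebuilt_index('cacm')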
Example #13
    if not searcher:
        exit()

    # Check PRF Flag
    if args.prf_depth > 0 and type(searcher) == SimpleDenseSearcher:
        PRF_FLAG = True
        if args.prf_method.lower() == 'avg':
            prfRule = DenseVectorAveragePrf()
        elif args.prf_method.lower() == 'rocchio':
            prfRule = DenseVectorRocchioPrf(args.rocchio_alpha, args.rocchio_beta)
        # ANCE-PRF uses its own query encoder, so DenseVectorAncePrf takes different inputs
        elif args.prf_method.lower() == 'ance-prf' and type(query_encoder) == AnceQueryEncoder:
            if os.path.exists(args.sparse_index):
                sparse_searcher = SimpleSearcher(args.sparse_index)
            else:
                sparse_searcher = SimpleSearcher.from_prebuilt_index(args.sparse_index)
            prf_query_encoder = AnceQueryEncoder(encoder_dir=args.ance_prf_encoder, tokenizer_name=args.tokenizer,
                                                 device=args.device)
            prfRule = DenseVectorAncePrf(prf_query_encoder, sparse_searcher)
        print(f'Running SimpleDenseSearcher with {args.prf_method.upper()} PRF...')
    else:
        PRF_FLAG = False

    # build output path
    output_path = args.output

    print(f'Running {args.topics} topics, saving to {output_path}...')
    tag = 'Faiss'

    output_writer = get_output_writer(output_path, OutputFormat(args.output_format), 'w',
                                      max_hits=args.hits, tag=tag, topics=topics,
Example #14
class DPRDemo(cmd.Cmd):
    nq_dev_topics = list(search.get_topics('dpr-nq-dev').values())
    trivia_dev_topics = list(search.get_topics('dpr-trivia-dev').values())

    ssearcher = SimpleSearcher.from_prebuilt_index('wikipedia-dpr')
    searcher = ssearcher

    encoder = DprQueryEncoder("facebook/dpr-question_encoder-multiset-base")
    index = 'wikipedia-dpr-multi-bf'
    dsearcher = SimpleDenseSearcher.from_prebuilt_index(
        index,
        encoder
    )
    hsearcher = HybridSearcher(dsearcher, ssearcher)

    k = 10
    prompt = '>>> '

    def precmd(self, line):
        if line[0] == '/':
            line = line[1:]
        return line

    def do_help(self, arg):
        print(f'/help    : returns this message')
        print(f'/k [NUM] : sets k (number of hits to return) to [NUM]')
        print(f'/mode [MODE] : sets retriever type to [MODE] (one of sparse, dense, hybrid)')
        print(f'/random [COLLECTION]: returns results for a random question from the dev subset [COLLECTION] (one of nq, trivia).')

    def do_k(self, arg):
        print(f'setting k = {int(arg)}')
        self.k = int(arg)

    def do_mode(self, arg):
        if arg == "sparse":
            self.searcher = self.ssearcher
        elif arg == "dense":
            self.searcher = self.dsearcher
        elif arg == "hybrid":
            self.searcher = self.hsearcher
        else:
            print(
                f'Mode "{arg}" is invalid. Mode should be one of [sparse, dense, hybrid].')
            return
        print(f'setting retriever = {arg}')

    def do_random(self, arg):
        if arg == "nq":
            topics = self.nq_dev_topics
        elif arg == "trivia":
            topics = self.trivia_dev_topics
        else:
            print(
                f'Collection "{arg}" is invalid. Collection should be one of [nq, trivia].')
            return
        q = random.choice(topics)['title']
        print(f'question: {q}')
        self.default(q)

    def do_EOF(self, line):
        return True

    def default(self, q):
        hits = self.searcher.search(q, self.k)

        for i in range(0, len(hits)):
            raw_doc = None
            if isinstance(self.searcher, SimpleSearcher):
                raw_doc = hits[i].raw
            else:
                doc = self.searcher.doc(hits[i].docid)
                if doc:
                    raw_doc = doc.raw()
            jsondoc = json.loads(raw_doc)
            print(f'{i + 1:2} {hits[i].score:.5f} {jsondoc["contents"]}')
Example #15
import sys

# Only for debugging purposes, using a Pyserini local installation.
# sys.path.insert(0, '../pyserini/')

from fastapi import FastAPI
from pyserini.search import SimpleSearcher
from typing import Optional


searcher = SimpleSearcher.from_prebuilt_index('msmarco-doc-slim')
app = FastAPI()


@app.get("/search/")
def search(q: str, k: Optional[int] = 1000):
    hits = searcher.search(q, k=k)

    results = []
    for i in range(0, len(hits)):
        results.append({'docid': hits[i].docid, 'score': hits[i].score})

    return {'results': results}
Example #16
class MsMarcoDemo(cmd.Cmd):
    dev_topics = list(search.get_topics('msmarco-passage-dev-subset').values())

    ssearcher = SimpleSearcher.from_prebuilt_index('msmarco-passage')
    dsearcher = None
    hsearcher = None
    searcher = ssearcher

    k = 10
    prompt = '>>> '

    # https://stackoverflow.com/questions/35213134/command-prefixes-in-python-cli-using-cmd-in-pythons-standard-library
    def precmd(self, line):
        if line[0] == '/':
            line = line[1:]
        return line

    def do_help(self, arg):
        print(f'/help    : returns this message')
        print(f'/k [NUM] : sets k (number of hits to return) to [NUM]')
        print(
            f'/model [MODEL] : sets encoder to use the model [MODEL] (one of tct, ance)'
        )
        print(
            f'/mode [MODE] : sets retriever type to [MODE] (one of sparse, dense, hybrid)'
        )
        print(
            f'/random : returns results for a random question from dev subset')

    def do_k(self, arg):
        print(f'setting k = {int(arg)}')
        self.k = int(arg)

    def do_mode(self, arg):
        if arg == "sparse":
            self.searcher = self.ssearcher
        elif arg == "dense":
            if self.dsearcher is None:
                print(
                    f'Specify model through /model before using dense retrieval.'
                )
                return
            self.searcher = self.dsearcher
        elif arg == "hybrid":
            if self.hsearcher is None:
                print(
                    f'Specify model through /model before using hybrid retrieval.'
                )
                return
            self.searcher = self.hsearcher
        else:
            print(
                f'Mode "{arg}" is invalid. Mode should be one of [sparse, dense, hybrid].'
            )
            return
        print(f'setting retriever = {arg}')

    def do_model(self, arg):
        if arg == "tct":
            encoder = TctColBertQueryEncoder("castorini/tct_colbert-msmarco")
            index = "msmarco-passage-tct_colbert-hnsw"
        elif arg == "ance":
            encoder = AnceQueryEncoder("castorini/ance-msmarco-passage")
            index = "msmarco-passage-ance-bf"
        else:
            print(
                f'Model "{arg}" is invalid. Model should be one of [tct, ance].'
            )
            return

        self.dsearcher = SimpleDenseSearcher.from_prebuilt_index(
            index, encoder)
        self.hsearcher = HybridSearcher(self.dsearcher, self.ssearcher)
        print(f'setting model = {arg}')

    def do_random(self, arg):
        q = random.choice(self.dev_topics)['title']
        print(f'question: {q}')
        self.default(q)

    def do_EOF(self, line):
        return True

    def default(self, q):
        hits = self.searcher.search(q, self.k)

        for i in range(0, len(hits)):
            raw_doc = None
            if isinstance(self.searcher, SimpleSearcher):
                raw_doc = hits[i].raw
            else:
                doc = self.searcher.doc(hits[i].docid)
                if doc:
                    raw_doc = doc.raw()
            jsondoc = json.loads(raw_doc)
            print(f'{i + 1:2} {hits[i].score:.5f} {jsondoc["contents"]}')
Example #17
    parser.add_argument('--max-passage-delimiter',
                        type=str,
                        metavar='str',
                        required=False,
                        default='#',
                        help="Delimiter between docid and passage id.")
    args = parser.parse_args()

    topics = get_topics(args.topics)

    if os.path.exists(args.index):
        # create searcher from index directory
        searcher = SimpleSearcher(args.index)
    else:
        # create searcher from prebuilt index name
        searcher = SimpleSearcher.from_prebuilt_index(args.index)

    if not searcher:
        exit()

    search_rankers = []

    if args.qld:
        search_rankers.append('qld')
        searcher.set_qld()
    else:
        search_rankers.append('bm25')

        if args.k1 is not None or args.b is not None:
            if args.k1 is None or args.b is None:
                print('Must set *both* k1 and b for BM25!')
Example #18
 def test_default_cache(self):
     # unsetenv() does not update os.environ, so remove the key directly.
     os.environ.pop('PYSERINI_CACHE', None)
     SimpleSearcher.from_prebuilt_index('cacm')
     self.assertTrue(
         os.path.exists(os.path.expanduser('~/.cache/pyserini/indexes')))