Example #1
import os

from pyserini.search.lucene import LuceneSearcher
from pyserini.trectools import TrecRun


def bm25(qid, query, docs, index_path):
    s = LuceneSearcher(index_path)
    hits = s.search(query, 1000)

    n = 1
    seen_docids = set()
    with open(f'run-passage-{qid}.txt', 'w') as writer:
        for hit in hits:
            if hit.docid in seen_docids:
                continue
            writer.write(f'{qid} Q0 {hit.docid} {n} {hit.score:.5f} pyserini\n')
            n += 1
            seen_docids.add(hit.docid)

    with open(f'run-doc-{qid}.txt', 'w') as writer:
        for doc in docs:
            writer.write(f'{qid} Q0 {doc["docid"]} {doc["rank"]} {doc["score"]} base\n')

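    # Fuse the passage and document runs with reciprocal rank fusion (RRF).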
    os.system(f'python -m pyserini.fusion --method rrf --runs run-passage-{qid}.txt run-doc-{qid}.txt ' +
              f'--output run-rrf-{qid}.txt --runtag test')
    fused_run = TrecRun(f'run-rrf-{qid}.txt')

    output = []
    for idx, r in fused_run.get_docs_by_topic(qid).iterrows():
        output.append([qid, r["docid"], r["rank"]])

    return output
Example #2
    def setUp(self):
        # Download pre-built CACM index; append a random value to avoid filename clashes.
        r = randint(0, 10000000)
        self.collection_url = 'https://github.com/castorini/anserini-data/raw/master/CACM/lucene-index.cacm.tar.gz'
        self.tarball_name = 'lucene-index.cacm-{}.tar.gz'.format(r)
        self.index_dir = 'index{}/'.format(r)

        filename, headers = urlretrieve(self.collection_url, self.tarball_name)

        tarball = tarfile.open(self.tarball_name)
        tarball.extractall(self.index_dir)
        tarball.close()

        self.searcher = LuceneSearcher(f'{self.index_dir}lucene-index.cacm')
Example #3
    def __init__(self, name, num_threads, index_dir=None, k1=0.9, b=0.4, use_bigrams=False, stem_bigrams=False):
        super().__init__(name)

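        # Cap the number of worker threads at the machine's CPU count.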
        self.num_threads = min(num_threads, int(multiprocessing.cpu_count()))

        # initialize a ranker per thread
        self.arguments = []
        for thread_id in tqdm(range(self.num_threads)):
            ranker = LuceneSearcher(index_dir)
            ranker.set_bm25(k1, b)
            self.arguments.append(
                {
                    "id": id,
                    "ranker": ranker,
                    "use_bigrams": use_bigrams,
                    "stem_bigrams": stem_bigrams
                }
            )
Example #4
def main(args):
    if args.cache and not os.path.exists(args.cache):
        os.mkdir(args.cache)

    # Load queries:
    queries = load_queries(args.queries)
    # Load base run to rerank:
    base_run = TrecRun(args.input)

    # LuceneSearcher to fetch document texts.
    searcher = LuceneSearcher.from_prebuilt_index('msmarco-doc')

    output = []

    if args.bm25:
        reranker = 'bm25'
    elif args.ance:
        reranker = 'ance'
    elif not args.identity:
        sys.exit('Unknown reranking method!')

    cnt = 1
    for row in queries:
        qid = int(row[0])
        query = row[1]
        print(f'{cnt} {qid} {query}')
        qid_results = base_run.get_docs_by_topic(qid)

        # Don't actually do reranking, just pass along the base run:
        if args.identity:
            rank = 1
            for docid in qid_results['docid'].tolist():
                output.append([qid, docid, rank])
                rank = rank + 1
            cnt = cnt + 1
            continue

        # Gather results for reranking:
        results_to_rerank = []
        for index, result in qid_results.iterrows():
            # lstrip/rstrip strip character sets, not prefixes, and can eat into the
            # text; removeprefix/removesuffix (Python 3.9+) strip the literal tags.
            raw_doc = searcher.doc(result['docid']).raw()
            raw_doc = raw_doc.removeprefix('<TEXT>').removesuffix('</TEXT>')
            results_to_rerank.append({'docid': result['docid'],
                                      'rank': result['rank'],
                                      'score': result['score'],
                                      'text': raw_doc})

        # Perform the actual reranking:
        output.extend(rerank(args.cache, qid, query, results_to_rerank, reranker))
        cnt = cnt + 1

    # Write the output run file:
    with open(args.output, 'w') as writer:
        for r in output:
            writer.write(f'{r[0]}\t{r[1]}\t{r[2]}\n')
Example #5
 def __init__(self, model: str, ibm_model:str, index:str, data: str, prebuilt: bool):
     #msmarco-ltr-passage
     self.model = model
     self.ibm_model = ibm_model
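     # For prebuilt indexes, resolve the on-disk path inside Pyserini's cache so
     # the feature extractor can open the index files directly.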
     if prebuilt:
         self.lucene_searcher = LuceneSearcher.from_prebuilt_index(index)
         index_directory = os.path.join(get_cache_home(), 'indexes')
         if data == 'passage':
             index_path = os.path.join(index_directory, 'index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3')
         else:
             index_path = os.path.join(index_directory, 'index-msmarco-doc-per-passage-ltr-20211031-33e4151.bd60e89041b4ebbabc4bf0cfac608a87')
         self.index_reader = IndexReader.from_prebuilt_index(index)
     else:
         index_path = index
         self.index_reader = IndexReader(index)
     self.fe = FeatureExtractor(index_path, max(multiprocessing.cpu_count()//2, 1))
     self.data = data
Example #6
    def __init__(self,
                 index_dir: str,
                 query_encoder: Union[QueryEncoder, PcaEncoder, str],
                 prebuilt_index_name: Optional[str] = None):
        requires_backends(self, "faiss")
        if isinstance(query_encoder, (QueryEncoder, PcaEncoder)):
            self.query_encoder = query_encoder
        else:
            self.query_encoder = self._init_encoder_from_str(query_encoder)
        self.index, self.docids = self.load_index(index_dir)
        self.dimension = self.index.d
        self.num_docs = self.index.ntotal

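        # If docids were loaded, they must align one-to-one with the FAISS index entries.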
        assert self.docids is None or self.num_docs == len(self.docids)
        if prebuilt_index_name:
            sparse_index = get_sparse_index(prebuilt_index_name)
            self.ssearcher = LuceneSearcher.from_prebuilt_index(sparse_index)
Example #7
 def __init__(self, ibm_model: str, index: str, field_name: str):
     self.ibm_model = ibm_model
     self.bm25search = LuceneSearcher.from_prebuilt_index(index)
     index_directory = os.path.join(get_cache_home(), 'indexes')
     if index == 'msmarco-passage-ltr':
         index_path = os.path.join(
             index_directory,
             'index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3'
         )
     elif index == 'msmarco-document-segment-ltr':
         index_path = os.path.join(
             index_directory,
             'lucene-index.msmarco-doc-segmented.ibm.13064bdaf8e8a79222634d67ecd3ddb5'
         )
     else:
         # Raise instead of printing: the original fell through to use an
         # undefined index_path, which would fail with a NameError.
         raise ValueError(
             'We currently only support two indexes: msmarco-passage-ltr and '
             'msmarco-document-segment-ltr, but the given index is not one of those.')
     self.object = JLuceneSearcher(index_path)
     self.index_reader = JIndexReader().getReader(index_path)
     self.field_name = field_name
     self.source_lookup, self.target_lookup, self.tran = self.load_tranprobs_table(
     )
     self.pool = ThreadPool(24)
Example #8
 def test_custom_cache(self):
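     # Point PYSERINI_CACHE at a custom directory; prebuilt indexes are downloaded there.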
     os.environ['PYSERINI_CACHE'] = 'temp_dir'
     LuceneSearcher.from_prebuilt_index('cacm')
     self.assertTrue(os.path.exists('temp_dir/indexes'))
Example #9
class TestSearch(unittest.TestCase):
    def setUp(self):
        # Download pre-built CACM index; append a random value to avoid filename clashes.
        r = randint(0, 10000000)
        self.collection_url = 'https://github.com/castorini/anserini-data/raw/master/CACM/lucene-index.cacm.tar.gz'
        self.tarball_name = 'lucene-index.cacm-{}.tar.gz'.format(r)
        self.index_dir = 'index{}/'.format(r)

        filename, headers = urlretrieve(self.collection_url, self.tarball_name)

        tarball = tarfile.open(self.tarball_name)
        tarball.extractall(self.index_dir)
        tarball.close()

        self.searcher = LuceneSearcher(f'{self.index_dir}lucene-index.cacm')

    def test_basic(self):
        self.assertTrue(
            self.searcher.get_similarity().toString().startswith('BM25'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(hits, List))

        self.assertTrue(isinstance(hits[0], JLuceneSearcherResult))
        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertEqual(hits[0].lucene_docid, 3133)
        self.assertEqual(len(hits[0].contents), 1500)
        self.assertEqual(len(hits[0].raw), 1532)
        self.assertAlmostEqual(hits[0].score, 4.76550, places=5)

        # Test accessing the raw Lucene document and fetching fields from it:
        self.assertEqual(hits[0].lucene_document.getField('id').stringValue(),
                         'CACM-3134')
        self.assertEqual(hits[0].lucene_document.get('id'),
                         'CACM-3134')  # simpler call, same result as above
        self.assertEqual(
            len(hits[0].lucene_document.getField('raw').stringValue()), 1532)
        self.assertEqual(len(hits[0].lucene_document.get('raw')),
                         1532)  # simpler call, same result as above

        self.assertTrue(isinstance(hits[9], JLuceneSearcherResult))
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 4.21740, places=5)

        hits = self.searcher.search('search')

        self.assertTrue(isinstance(hits[0], JLuceneSearcherResult))
        self.assertEqual(hits[0].docid, 'CACM-3058')
        self.assertAlmostEqual(hits[0].score, 2.85760, places=5)

        self.assertTrue(isinstance(hits[9], JLuceneSearcherResult))
        self.assertEqual(hits[9].docid, 'CACM-3040')
        self.assertAlmostEqual(hits[9].score, 2.68780, places=5)

    def test_batch(self):
        results = self.searcher.batch_search(
            ['information retrieval', 'search'], ['q1', 'q2'], threads=2)

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(results, Dict))

        self.assertTrue(isinstance(results['q1'], List))
        self.assertTrue(isinstance(results['q1'][0], JLuceneSearcherResult))
        self.assertEqual(results['q1'][0].docid, 'CACM-3134')
        self.assertAlmostEqual(results['q1'][0].score, 4.76550, places=5)

        self.assertTrue(isinstance(results['q1'][9], JLuceneSearcherResult))
        self.assertEqual(results['q1'][9].docid, 'CACM-2516')
        self.assertAlmostEqual(results['q1'][9].score, 4.21740, places=5)

        self.assertTrue(isinstance(results['q2'], List))
        self.assertTrue(isinstance(results['q2'][0], JLuceneSearcherResult))
        self.assertEqual(results['q2'][0].docid, 'CACM-3058')
        self.assertAlmostEqual(results['q2'][0].score, 2.85760, places=5)

        self.assertTrue(isinstance(results['q2'][9], JLuceneSearcherResult))
        self.assertEqual(results['q2'][9].docid, 'CACM-3040')
        self.assertAlmostEqual(results['q2'][9].score, 2.68780, places=5)

    def test_basic_k(self):
        hits = self.searcher.search('information retrieval', k=100)

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(hits, List))
        self.assertTrue(isinstance(hits[0], JLuceneSearcherResult))
        self.assertEqual(len(hits), 100)

    def test_batch_k(self):
        results = self.searcher.batch_search(
            ['information retrieval', 'search'], ['q1', 'q2'],
            k=100,
            threads=2)

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(results, Dict))
        self.assertTrue(isinstance(results['q1'], List))
        self.assertTrue(isinstance(results['q1'][0], JLuceneSearcherResult))
        self.assertEqual(len(results['q1']), 100)
        self.assertTrue(isinstance(results['q2'], List))
        self.assertTrue(isinstance(results['q2'][0], JLuceneSearcherResult))
        self.assertEqual(len(results['q2']), 100)

    def test_basic_fields(self):
        # This test just provides a sanity check; it's not that interesting as it only searches one field.
        hits = self.searcher.search('information retrieval',
                                    k=42,
                                    fields={'contents': 2.0})

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(hits, List))
        self.assertTrue(isinstance(hits[0], JLuceneSearcherResult))
        self.assertEqual(len(hits), 42)

    def test_batch_fields(self):
        # This test just provides a sanity check; it's not that interesting as it only searches one field.
        results = self.searcher.batch_search(
            ['information retrieval', 'search'], ['q1', 'q2'],
            k=42,
            threads=2,
            fields={'contents': 2.0})

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(results, Dict))
        self.assertTrue(isinstance(results['q1'], List))
        self.assertTrue(isinstance(results['q1'][0], JLuceneSearcherResult))
        self.assertEqual(len(results['q1']), 42)
        self.assertTrue(isinstance(results['q2'], List))
        self.assertTrue(isinstance(results['q2'][0], JLuceneSearcherResult))
        self.assertEqual(len(results['q2']), 42)

    def test_different_similarity(self):
        # qld, default mu
        self.searcher.set_qld()
        self.assertTrue(self.searcher.get_similarity().toString().startswith(
            'LM Dirichlet'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 3.68030, places=5)
        self.assertEqual(hits[9].docid, 'CACM-1927')
        self.assertAlmostEqual(hits[9].score, 2.53240, places=5)

        # bm25, default parameters
        self.searcher.set_bm25()
        self.assertTrue(
            self.searcher.get_similarity().toString().startswith('BM25'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 4.76550, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 4.21740, places=5)

        # qld, custom mu
        self.searcher.set_qld(100)
        self.assertTrue(self.searcher.get_similarity().toString().startswith(
            'LM Dirichlet'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 6.35580, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2631')
        self.assertAlmostEqual(hits[9].score, 5.18960, places=5)

        # bm25, custom parameters
        self.searcher.set_bm25(0.8, 0.3)
        self.assertTrue(
            self.searcher.get_similarity().toString().startswith('BM25'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 4.86880, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 4.33320, places=5)

    def test_rm3(self):
        self.searcher.set_rm3()
        self.assertTrue(self.searcher.is_using_rm3())

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 2.18010, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 1.70330, places=5)

        self.searcher.unset_rm3()
        self.assertFalse(self.searcher.is_using_rm3())

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 4.76550, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 4.21740, places=5)

        self.searcher.set_rm3(fb_docs=4, fb_terms=6, original_query_weight=0.3)
        self.assertTrue(self.searcher.is_using_rm3())

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 2.17190, places=5)
        self.assertEqual(hits[9].docid, 'CACM-1457')
        self.assertAlmostEqual(hits[9].score, 1.43700, places=5)

    def test_doc_int(self):
        # The doc method is overloaded: if input is int, it's assumed to be a Lucene internal docid.
        doc = self.searcher.doc(1)
        self.assertTrue(isinstance(doc, Document))

        # These are all equivalent ways to get the docid.
        self.assertEqual('CACM-0002', doc.id())
        self.assertEqual('CACM-0002', doc.docid())
        self.assertEqual('CACM-0002', doc.get('id'))
        self.assertEqual('CACM-0002',
                         doc.lucene_document().getField('id').stringValue())

        # These are all equivalent ways to get the 'raw' field
        self.assertEqual(186, len(doc.raw()))
        self.assertEqual(186, len(doc.get('raw')))
        self.assertEqual(186, len(doc.lucene_document().get('raw')))
        self.assertEqual(
            186, len(doc.lucene_document().getField('raw').stringValue()))

        # These are all equivalent ways to get the 'contents' field
        self.assertEqual(154, len(doc.contents()))
        self.assertEqual(154, len(doc.get('contents')))
        self.assertEqual(154, len(doc.lucene_document().get('contents')))
        self.assertEqual(
            154, len(doc.lucene_document().getField('contents').stringValue()))

        # Should return None if we request a docid that doesn't exist
        self.assertTrue(self.searcher.doc(314159) is None)

    def test_doc_str(self):
        # The doc method is overloaded: if input is str, it's assumed to be an external collection docid.
        doc = self.searcher.doc('CACM-0002')
        self.assertTrue(isinstance(doc, Document))

        # These are all equivalent ways to get the docid.
        self.assertEqual(doc.lucene_document().getField('id').stringValue(),
                         'CACM-0002')
        self.assertEqual(doc.id(), 'CACM-0002')
        self.assertEqual(doc.docid(), 'CACM-0002')
        self.assertEqual(doc.get('id'), 'CACM-0002')

        # These are all equivalent ways to get the 'raw' field
        self.assertEqual(186, len(doc.raw()))
        self.assertEqual(186, len(doc.get('raw')))
        self.assertEqual(186, len(doc.lucene_document().get('raw')))
        self.assertEqual(
            186, len(doc.lucene_document().getField('raw').stringValue()))

        # These are all equivalent ways to get the 'contents' field
        self.assertEqual(154, len(doc.contents()))
        self.assertEqual(154, len(doc.get('contents')))
        self.assertEqual(154, len(doc.lucene_document().get('contents')))
        self.assertEqual(
            154, len(doc.lucene_document().getField('contents').stringValue()))

        # Should return None if we request a docid that doesn't exist
        self.assertTrue(self.searcher.doc('foo') is None)

    def test_doc_by_field(self):
        self.assertEqual(
            self.searcher.doc('CACM-3134').docid(),
            self.searcher.doc_by_field('id', 'CACM-3134').docid())

        # Should return None if we request a docid that doesn't exist
        self.assertTrue(self.searcher.doc_by_field('foo', 'bar') is None)

    def tearDown(self):
        self.searcher.close()
        os.remove(self.tarball_name)
        shutil.rmtree(self.index_dir)
Example #10
    if not searcher:
        exit()

    # Check PRF Flag
    if args.prf_depth > 0 and isinstance(searcher, FaissSearcher):
        PRF_FLAG = True
        if args.prf_method.lower() == 'avg':
            prfRule = DenseVectorAveragePrf()
        elif args.prf_method.lower() == 'rocchio':
            prfRule = DenseVectorRocchioPrf(args.rocchio_alpha,
                                            args.rocchio_beta)
        # ANCE-PRF is using a new query encoder, so the input to DenseVectorAncePrf is different
        elif args.prf_method.lower() == 'ance-prf' and isinstance(
                query_encoder, AnceQueryEncoder):
            if os.path.exists(args.sparse_index):
                sparse_searcher = LuceneSearcher(args.sparse_index)
            else:
                sparse_searcher = LuceneSearcher.from_prebuilt_index(
                    args.sparse_index)
            prf_query_encoder = AnceQueryEncoder(
                encoder_dir=args.ance_prf_encoder,
                tokenizer_name=args.tokenizer,
                device=args.device)
            prfRule = DenseVectorAncePrf(prf_query_encoder, sparse_searcher)
        print(f'Running FaissSearcher with {args.prf_method.upper()} PRF...')
    else:
        PRF_FLAG = False

    # build output path
    output_path = args.output
Example #11
    parser.add_argument("--bin-width",
                        type=int,
                        help='Width of each bin.',
                        default=50)
    parser.add_argument("--plot",
                        type=str,
                        help='Output file of histogram PDF.')
    parser.add_argument("--output",
                        type=str,
                        help='Prefix of raw count and bin data file.')

    args = parser.parse_args()

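    # Use a non-interactive matplotlib backend so the plot can be written headless.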
    plt.switch_backend('agg')

    searcher = LuceneSearcher(args.index)

    # Determine how many documents to iterate over:
    if args.max:
        num_docs = args.max if args.max < searcher.num_docs else searcher.num_docs
    else:
        num_docs = searcher.num_docs

    print(f'Computing lengths for {num_docs} documents from {args.index}')
    doclengths = []
    for i in tqdm(range(num_docs)):
        doclengths.append(len(searcher.doc(i).raw().split()))

    doclengths = np.asarray(doclengths)

    # Compute bins:
Example #12
from pyserini.eval.evaluate_dpr_retrieval import has_answers, SimpleTokenizer

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert a TREC run to DPR retrieval result json.')
    parser.add_argument('--topics', required=True, help='topic name')
    parser.add_argument('--index', required=True, help='Anserini index that contains the raw document text')
    parser.add_argument('--input', required=True, help='Input TREC run file.')
    parser.add_argument('--store-raw', action='store_true', help='Store raw text of passage')
    parser.add_argument('--regex', action='store_true', default=False, help="regex match")
    parser.add_argument('--output', required=True, help='Output DPR Retrieval json file.')
    args = parser.parse_args()

    qas = get_topics(args.topics)

    if os.path.exists(args.index):
        searcher = LuceneSearcher(args.index)
    else:
        searcher = LuceneSearcher.from_prebuilt_index(args.index)
    if not searcher:
        exit()

    retrieval = {}
    tokenizer = SimpleTokenizer()
    with open(args.input) as f_in:
        for line in tqdm(f_in.readlines()):
            question_id, _, doc_id, _, score, _ = line.strip().split()
            question_id = int(question_id)
            question = qas[question_id]['title']
            answers = qas[question_id]['answers']
            if answers[0] == '"':
                answers = answers[1:-1].replace('""', '"')
Example #13
class DPRDemo(cmd.Cmd):
    nq_dev_topics = list(search.get_topics('dpr-nq-dev').values())
    trivia_dev_topics = list(search.get_topics('dpr-trivia-dev').values())

    ssearcher = LuceneSearcher.from_prebuilt_index('wikipedia-dpr')
    searcher = ssearcher

    encoder = DprQueryEncoder("facebook/dpr-question_encoder-multiset-base")
    index = 'wikipedia-dpr-multi-bf'
    dsearcher = FaissSearcher.from_prebuilt_index(index, encoder)
    hsearcher = HybridSearcher(dsearcher, ssearcher)

    k = 10
    prompt = '>>> '

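    # Strip the leading '/' so commands like '/k 10' dispatch to the do_* handlers.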
    def precmd(self, line):
        if line.startswith('/'):
            line = line[1:]
        return line

    def do_help(self, arg):
        print(f'/help    : returns this message')
        print(f'/k [NUM] : sets k (number of hits to return) to [NUM]')
        print(
            f'/mode [MODE] : sets retriever type to [MODE] (one of sparse, dense, hybrid)'
        )
        print(
            f'/random [COLLECTION]: returns results for a random question from the dev subset [COLLECTION] (one of nq, trivia).'
        )

    def do_k(self, arg):
        print(f'setting k = {int(arg)}')
        self.k = int(arg)

    def do_mode(self, arg):
        if arg == "sparse":
            self.searcher = self.ssearcher
        elif arg == "dense":
            self.searcher = self.dsearcher
        elif arg == "hybrid":
            self.searcher = self.hsearcher
        else:
            print(
                f'Mode "{arg}" is invalid. Mode should be one of [sparse, dense, hybrid].'
            )
            return
        print(f'setting retriever = {arg}')

    def do_random(self, arg):
        if arg == "nq":
            topics = self.nq_dev_topics
        elif arg == "trivia":
            topics = self.trivia_dev_topics
        else:
            print(
                f'Collection "{arg}" is invalid. Collection should be one of [nq, trivia].'
            )
            return
        q = random.choice(topics)['title']
        print(f'question: {q}')
        self.default(q)

    def do_EOF(self, line):
        return True

    def default(self, q):
        hits = self.searcher.search(q, self.k)

        for i in range(0, len(hits)):
            raw_doc = None
            if isinstance(self.searcher, LuceneSearcher):
                raw_doc = hits[i].raw
            else:
                doc = self.searcher.doc(hits[i].docid)
                if doc:
                    raw_doc = doc.raw()
            jsondoc = json.loads(raw_doc)
            print(f'{i + 1:2} {hits[i].score:.5f} {jsondoc["contents"]}')
Example #14

import argparse
import json
import sys

# We're going to explicitly use a local installation of Pyserini (as opposed to a pip-installed one).
# Comment these lines out to use a pip-installed one instead.
sys.path.insert(0, './')

from pyserini.search.lucene import LuceneSearcher

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--qrels', type=str, help='qrels file', required=True)
    parser.add_argument('--index',
                        type=str,
                        help='index location',
                        required=True)
    args = parser.parse_args()

    searcher = LuceneSearcher(args.index)
    with open(args.qrels, 'r') as reader:
        for line in reader.readlines():
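            # Assumes tab-separated TREC qrels ('qid iter docid rel'): arr[2] is the docid.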
            arr = line.split('\t')
            doc = json.loads(searcher.doc(arr[2]).raw())['contents']
            print(f'{arr[2]}\t{doc}')
Example #15
class TestAnalyzers(unittest.TestCase):
    def setUp(self):
        # Download pre-built CACM index; append a random value to avoid filename clashes.
        r = randint(0, 10000000)
        self.collection_url = 'https://github.com/castorini/anserini-data/raw/master/CACM/lucene-index.cacm.tar.gz'
        self.tarball_name = 'lucene-index.cacm-{}.tar.gz'.format(r)
        self.index_dir = 'index{}/'.format(r)

        _, _ = urlretrieve(self.collection_url, self.tarball_name)

        tarball = tarfile.open(self.tarball_name)
        tarball.extractall(self.index_dir)
        tarball.close()
        self.searcher = LuceneSearcher(f'{self.index_dir}lucene-index.cacm')
        self.index_utils = IndexReader(f'{self.index_dir}lucene-index.cacm')

    def test_different_analyzers_are_different(self):
        self.searcher.set_analyzer(get_lucene_analyzer(stemming=False))
        hits_first = self.searcher.search('information retrieval')
        self.searcher.set_analyzer(get_lucene_analyzer())
        hits_second = self.searcher.search('information retrieval')
        self.assertNotEqual(hits_first, hits_second)

    def test_analyze_with_analyzer(self):
        analyzer = get_lucene_analyzer(stemming=False)
        self.assertTrue(isinstance(analyzer, JAnalyzer))
        query = 'information retrieval'
        only_tokenization = JAnalyzerUtils.analyze(analyzer, query)
        token_list = []
        for token in only_tokenization.toArray():
            token_list.append(token)
        self.assertEqual(token_list, ['information', 'retrieval'])

    def test_analysis(self):
        # Default is Porter stemmer
        analyzer = Analyzer(get_lucene_analyzer())
        self.assertTrue(isinstance(analyzer, Analyzer))
        tokens = analyzer.analyze('City buses are running on time.')
        self.assertEqual(tokens, ['citi', 'buse', 'run', 'time'])

        # Specify Porter stemmer explicitly
        analyzer = Analyzer(get_lucene_analyzer(stemmer='porter'))
        self.assertTrue(isinstance(analyzer, Analyzer))
        tokens = analyzer.analyze('City buses are running on time.')
        self.assertEqual(tokens, ['citi', 'buse', 'run', 'time'])

        # Specify Krovetz stemmer explicitly
        analyzer = Analyzer(get_lucene_analyzer(stemmer='krovetz'))
        self.assertTrue(isinstance(analyzer, Analyzer))
        tokens = analyzer.analyze('City buses are running on time.')
        self.assertEqual(tokens, ['city', 'bus', 'running', 'time'])

        # No stemming
        analyzer = Analyzer(get_lucene_analyzer(stemming=False))
        self.assertTrue(isinstance(analyzer, Analyzer))
        tokens = analyzer.analyze('City buses are running on time.')
        self.assertEqual(tokens, ['city', 'buses', 'running', 'time'])

        # No stopword filter, no stemming
        analyzer = Analyzer(
            get_lucene_analyzer(stemming=False, stopwords=False))
        self.assertTrue(isinstance(analyzer, Analyzer))
        tokens = analyzer.analyze('City buses are running on time.')
        self.assertEqual(tokens,
                         ['city', 'buses', 'are', 'running', 'on', 'time'])

        # No stopword filter, with stemming
        analyzer = Analyzer(get_lucene_analyzer(stemming=True,
                                                stopwords=False))
        self.assertTrue(isinstance(analyzer, Analyzer))
        tokens = analyzer.analyze('City buses are running on time.')
        self.assertEqual(tokens, ['citi', 'buse', 'ar', 'run', 'on', 'time'])

    def test_invalid_analyzer_wrapper(self):
        # Invalid JAnalyzer, make sure we get an exception.
        with self.assertRaises(TypeError):
            Analyzer('str')

    def test_invalid_analysis(self):
        # Invalid configuration, make sure we get an exception.
        with self.assertRaises(ValueError):
            Analyzer(get_lucene_analyzer('blah'))

    def tearDown(self):
        self.searcher.close()
        os.remove(self.tarball_name)
        shutil.rmtree(self.index_dir)
Example #16
 def test_default_cache(self):
     LuceneSearcher.from_prebuilt_index('cacm')
     self.assertTrue(
         os.path.exists(os.path.expanduser('~/.cache/pyserini/indexes')))
Example #17
class TestQueryBuilding(unittest.TestCase):
    def setUp(self):
        # Download pre-built CACM index; append a random value to avoid filename clashes.
        r = randint(0, 10000000)
        self.collection_url = 'https://github.com/castorini/anserini-data/raw/master/CACM/lucene-index.cacm.tar.gz'
        self.tarball_name = 'lucene-index.cacm-{}.tar.gz'.format(r)
        self.index_dir = 'index{}/'.format(r)

        filename, headers = urlretrieve(self.collection_url, self.tarball_name)

        tarball = tarfile.open(self.tarball_name)
        tarball.extractall(self.index_dir)
        tarball.close()

        self.searcher = LuceneSearcher(f'{self.index_dir}lucene-index.cacm')

    def testBuildBoostedQuery(self):
        term_query1 = querybuilder.get_term_query('information')
        term_query2 = querybuilder.get_term_query('retrieval')

        boost1 = querybuilder.get_boost_query(term_query1, 2.)
        boost2 = querybuilder.get_boost_query(term_query2, 2.)

        should = querybuilder.JBooleanClauseOccur['should'].value

        boolean_query = querybuilder.get_boolean_query_builder()
        boolean_query.add(boost1, should)
        boolean_query.add(boost2, should)

        bq = boolean_query.build()
        hits1 = self.searcher.search(bq)

        boolean_query2 = querybuilder.get_boolean_query_builder()
        boolean_query2.add(term_query1, should)
        boolean_query2.add(term_query2, should)

        bq2 = boolean_query2.build()
        hits2 = self.searcher.search(bq2)

        for h1, h2 in zip(hits1, hits2):
            self.assertEqual(h1.docid, h2.docid)
            self.assertAlmostEqual(h1.score, h2.score*2, delta=0.001)

        boost3 = querybuilder.get_boost_query(term_query1, 2.)
        boost4 = querybuilder.get_boost_query(term_query2, 3.)

        boolean_query = querybuilder.get_boolean_query_builder()
        boolean_query.add(boost3, should)
        boolean_query.add(boost4, should)

        bq3 = boolean_query.build()
        hits3 = self.searcher.search(bq3)

        for h1, h3 in zip(hits1, hits3):
            self.assertNotEqual(h1.score, h3.score)

    def testTermQuery(self):
        should = querybuilder.JBooleanClauseOccur['should'].value
        query_builder = querybuilder.get_boolean_query_builder()
        query_builder.add(querybuilder.get_term_query('information'), should)
        query_builder.add(querybuilder.get_term_query('retrieval'), should)

        query = query_builder.build()
        hits1 = self.searcher.search(query)
        hits2 = self.searcher.search('information retrieval')

        for h1, h2 in zip(hits1, hits2):
            self.assertEqual(h1.docid, h2.docid)
            self.assertEqual(h1.score, h2.score)

    def testIncompatibilityWithRM3(self):
        should = querybuilder.JBooleanClauseOccur['should'].value
        query_builder = querybuilder.get_boolean_query_builder()
        query_builder.add(querybuilder.get_term_query('information'), should)
        query_builder.add(querybuilder.get_term_query('retrieval'), should)

        query = query_builder.build()
        hits = self.searcher.search(query)
        self.assertEqual(10, len(hits))

        self.searcher.set_rm3()
        self.assertTrue(self.searcher.is_using_rm3())

        with self.assertRaises(NotImplementedError):
            self.searcher.search(query)

    def testTermQuery2(self):
        term_query1 = querybuilder.get_term_query('inform', analyzer=get_lucene_analyzer(stemming=False))
        term_query2 = querybuilder.get_term_query('retriev', analyzer=get_lucene_analyzer(stemming=False))

        should = querybuilder.JBooleanClauseOccur['should'].value

        boolean_query1 = querybuilder.get_boolean_query_builder()
        boolean_query1.add(term_query1, should)
        boolean_query1.add(term_query2, should)

        bq1 = boolean_query1.build()
        hits1 = self.searcher.search(bq1)
        hits2 = self.searcher.search('information retrieval')

        for h1, h2 in zip(hits1, hits2):
            self.assertEqual(h1.docid, h2.docid)
            self.assertEqual(h1.score, h2.score)

    def tearDown(self):
        self.searcher.close()
        os.remove(self.tarball_name)
        shutil.rmtree(self.index_dir)
Example #18
def dev_data_loader(file, format, data, rerank, prebuilt, top=1000):
    if rerank:
        if format == 'tsv':
            dev = pd.read_csv(file,
                              sep="\t",
                              names=['qid', 'pid', 'rank'],
                              dtype={
                                  'qid': 'S',
                                  'pid': 'S',
                                  'rank': 'i',
                              })
        elif format == 'trec':
            dev = pd.read_csv(
                file,
                sep="\s+",
                names=['qid', 'q0', 'pid', 'rank', 'score', 'tag'],
                usecols=['qid', 'pid', 'rank'],
                dtype={
                    'qid': 'S',
                    'pid': 'S',
                    'rank': 'i',
                })
        else:
            raise ValueError(f'unknown run format: {format}')
        assert dev['qid'].dtype == object
        assert dev['pid'].dtype == object
        assert dev['rank'].dtype == np.int32
        dev = dev[dev['rank'] <= top]
    else:
        if prebuilt:
            bm25search = LuceneSearcher.from_prebuilt_index(args.index)
        else:
            bm25search = LuceneSearcher(args.index)
        bm25search.set_bm25(0.82, 0.68)
        dev_dic = {"qid": [], "pid": [], "rank": []}
        for topic in tqdm(queries.keys()):
            query_text = queries[topic]['raw']
            bm25_dev = bm25search.search(query_text, args.hits)
            doc_ids = [bm25_result.docid for bm25_result in bm25_dev]
            qid = [topic for _ in range(len(doc_ids))]
            rank = [i for i in range(1, len(doc_ids) + 1)]
            dev_dic['qid'].extend(qid)
            dev_dic['pid'].extend(doc_ids)
            dev_dic['rank'].extend(rank)
        dev = pd.DataFrame(dev_dic)
        dev['rank'] = dev['rank'].astype(np.int32)

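    # Load the matching MS MARCO qrels so each candidate can be labeled with relevance.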
    if data == 'passage':
        dev_qrel = pd.read_csv(
            'tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt',
            sep=" ",
            names=["qid", "q0", "pid", "rel"],
            usecols=['qid', 'pid', 'rel'],
            dtype={
                'qid': 'S',
                'pid': 'S',
                'rel': 'i'
            })
    elif data == 'document':
        dev_qrel = pd.read_csv(
            'tools/topics-and-qrels/qrels.msmarco-doc.dev.txt',
            sep="\t",
            names=["qid", "q0", "pid", "rel"],
            usecols=['qid', 'pid', 'rel'],
            dtype={
                'qid': 'S',
                'pid': 'S',
                'rel': 'i'
            })
    dev = dev.merge(dev_qrel,
                    left_on=['qid', 'pid'],
                    right_on=['qid', 'pid'],
                    how='left')
    dev['rel'] = dev['rel'].fillna(0).astype(np.int32)
    dev = dev.sort_values(['qid', 'pid']).set_index(['qid', 'pid'])

    print(dev.shape)
    print(dev.index.get_level_values('qid').drop_duplicates().shape)
    print(dev.groupby('qid').count().mean())
    print(dev.head(10))
    print(dev.info())

    dev_rel_num = dev_qrel[dev_qrel['rel'] > 0].groupby('qid').count()['rel']

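    # Build a recall@k curve: for each query, count relevant hits at each cutoff
    # and normalize by that query's total number of relevant documents.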
    recall_point = [10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000]
    recall_curve = {k: [] for k in recall_point}
    for qid, group in tqdm(dev.groupby('qid')):
        group = group.reset_index()
        assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
        total_rel = dev_rel_num.loc[qid]
        query_recall = [0 for k in recall_point]
        for t in group.sort_values('rank').itertuples():
            if t.rel > 0:
                for i, p in enumerate(recall_point):
                    if t.rank <= p:
                        query_recall[i] += 1
        for i, p in enumerate(recall_point):
            if total_rel > 0:
                recall_curve[p].append(query_recall[i] / total_rel)
            else:
                recall_curve[p].append(0.)

    for k, v in recall_curve.items():
        avg = np.mean(v)
        print(f'recall@{k}:{avg}')

    return dev, dev_qrel
Example #19
class MsMarcoDemo(cmd.Cmd):
    dev_topics = list(search.get_topics('msmarco-passage-dev-subset').values())

    ssearcher = LuceneSearcher.from_prebuilt_index('msmarco-passage')
    dsearcher = None
    hsearcher = None
    searcher = ssearcher

    k = 10
    prompt = '>>> '

    # https://stackoverflow.com/questions/35213134/command-prefixes-in-python-cli-using-cmd-in-pythons-standard-library
    def precmd(self, line):
        if line.startswith('/'):
            line = line[1:]
        return line

    def do_help(self, arg):
        print(f'/help    : returns this message')
        print(f'/k [NUM] : sets k (number of hits to return) to [NUM]')
        print(
            f'/model [MODEL] : sets encoder to use the model [MODEL] (one of tct, ance)'
        )
        print(
            f'/mode [MODE] : sets retriever type to [MODE] (one of sparse, dense, hybrid)'
        )
        print(
            f'/random : returns results for a random question from dev subset')

    def do_k(self, arg):
        print(f'setting k = {int(arg)}')
        self.k = int(arg)

    def do_mode(self, arg):
        if arg == "sparse":
            self.searcher = self.ssearcher
        elif arg == "dense":
            if self.dsearcher is None:
                print(
                    f'Specify model through /model before using dense retrieval.'
                )
                return
            self.searcher = self.dsearcher
        elif arg == "hybrid":
            if self.hsearcher is None:
                print(
                    f'Specify model through /model before using hybrid retrieval.'
                )
                return
            self.searcher = self.hsearcher
        else:
            print(
                f'Mode "{arg}" is invalid. Mode should be one of [sparse, dense, hybrid].'
            )
            return
        print(f'setting retriever = {arg}')

    def do_model(self, arg):
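        # The dense and hybrid searchers are built lazily, the first time a model is chosen.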
        if arg == "tct":
            encoder = TctColBertQueryEncoder("castorini/tct_colbert-msmarco")
            index = "msmarco-passage-tct_colbert-hnsw"
        elif arg == "ance":
            encoder = AnceQueryEncoder("castorini/ance-msmarco-passage")
            index = "msmarco-passage-ance-bf"
        else:
            print(
                f'Model "{arg}" is invalid. Model should be one of [tct, ance].'
            )
            return

        self.dsearcher = FaissSearcher.from_prebuilt_index(index, encoder)
        self.hsearcher = HybridSearcher(self.dsearcher, self.ssearcher)
        print(f'setting model = {arg}')

    def do_random(self, arg):
        q = random.choice(self.dev_topics)['title']
        print(f'question: {q}')
        self.default(q)

    def do_EOF(self, line):
        return True

    def default(self, q):
        hits = self.searcher.search(q, self.k)

        for i in range(0, len(hits)):
            raw_doc = None
            if isinstance(self.searcher, LuceneSearcher):
                raw_doc = hits[i].raw
            else:
                doc = self.searcher.doc(hits[i].docid)
                if doc:
                    raw_doc = doc.raw()
            jsondoc = json.loads(raw_doc)
            print(f'{i + 1:2} {hits[i].score:.5f} {jsondoc["contents"]}')