예제 #1
0
    def test_car20(self):
        """Check the 'car17v2.0-benchmarkY1test' topics and qrels.

        Both are expected to hold 2254 entries and to be keyed by
        something other than ``int``.
        """
        name = 'car17v2.0-benchmarkY1test'

        topics = search.get_topics(name)
        self.assertIsNotNone(topics)
        self.assertEqual(len(topics), 2254)
        # assertNotIsInstance reports the offending key when the check fails.
        self.assertNotIsInstance(next(iter(topics.keys())), int)

        qrels = search.get_qrels(name)
        self.assertIsNotNone(qrels)
        self.assertEqual(len(qrels), 2254)
        self.assertNotIsInstance(next(iter(qrels.keys())), int)
예제 #2
0
    def test_clef2006_fr(self):
        """Check the 'clef2006-fr' topics and qrels: 49 entries each, ``str`` keys."""
        name = 'clef2006-fr'

        topics = search.get_topics(name)
        self.assertIsNotNone(topics)
        self.assertEqual(len(topics), 49)
        # assertIsInstance shows the actual key type when the check fails.
        self.assertIsInstance(next(iter(topics.keys())), str)

        qrels = search.get_qrels(name)
        self.assertIsNotNone(qrels)
        self.assertEqual(len(qrels), 49)
        self.assertIsInstance(next(iter(qrels.keys())), str)
예제 #3
0
    def test_trec2004_terabyte(self):
        """Check 'trec2004-terabyte': 50 topics, 49 qrels entries, ``int`` keys."""
        name = 'trec2004-terabyte'

        topics = search.get_topics(name)
        self.assertIsNotNone(topics)
        self.assertEqual(len(topics), 50)
        # assertIsInstance shows the actual key type when the check fails.
        self.assertIsInstance(next(iter(topics.keys())), int)

        qrels = search.get_qrels(name)
        self.assertIsNotNone(qrels)
        self.assertEqual(len(qrels), 49)
        self.assertIsInstance(next(iter(qrels.keys())), int)
예제 #4
0
    def test_covid_round5(self):
        """Check 'covid-round5': 50 topics and 50 qrels entries, ``int`` keys."""
        name = 'covid-round5'

        topics = search.get_topics(name)
        self.assertIsNotNone(topics)
        self.assertEqual(len(topics), 50)
        # assertIsInstance shows the actual key type when the check fails.
        self.assertIsInstance(next(iter(topics.keys())), int)

        qrels = search.get_qrels(name)
        self.assertIsNotNone(qrels)
        self.assertEqual(len(qrels), 50)
        self.assertIsInstance(next(iter(qrels.keys())), int)
예제 #5
0
 def test_covid_round4_udel(self):
     """Check 'covid-round4-udel': 45 topics, the queries of topics 1 and 45, ``int`` keys."""
     topics = search.get_topics('covid-round4-udel')
     self.assertIsNotNone(topics)
     self.assertEqual(len(topics), 45)
     self.assertEqual('coronavirus origin origin COVID-19',
                      topics[1]['query'])
     self.assertEqual(
         'coronavirus mental health impact COVID-19 pandemic impacted mental health',
         topics[45]['query'])
     # assertIsInstance shows the actual key type when the check fails.
     self.assertIsInstance(next(iter(topics.keys())), int)
예제 #6
0
    def test_mb12(self):
        """Check 'mb12': 60 topics, 59 qrels entries, ``int`` keys."""
        name = 'mb12'

        topics = search.get_topics(name)
        self.assertIsNotNone(topics)
        self.assertEqual(len(topics), 60)
        # assertIsInstance shows the actual key type when the check fails.
        self.assertIsInstance(next(iter(topics.keys())), int)

        qrels = search.get_qrels(name)
        self.assertIsNotNone(qrels)
        self.assertEqual(len(qrels), 59)
        self.assertIsInstance(next(iter(qrels.keys())), int)
예제 #7
0
 def test_covid_round2_udel(self):
     """Check 'covid-round2-udel': 35 topics, the queries of topics 1 and 35, ``int`` keys."""
     topics = search.get_topics('covid-round2-udel')
     self.assertIsNotNone(topics)
     self.assertEqual(len(topics), 35)
     self.assertEqual('coronavirus origin origin COVID-19',
                      topics[1]['query'])
     self.assertEqual(
         'coronavirus public datasets public datasets COVID-19',
         topics[35]['query'])
     # assertIsInstance shows the actual key type when the check fails.
     self.assertIsInstance(next(iter(topics.keys())), int)
예제 #8
0
 def test_covid_round3_udel(self):
     """Check 'covid-round3-udel': 40 topics, the queries of topics 1 and 40, ``int`` keys."""
     topics = search.get_topics('covid-round3-udel')
     self.assertIsNotNone(topics)
     self.assertEqual(len(topics), 40)
     self.assertEqual('coronavirus origin origin COVID-19',
                      topics[1]['query'])
     self.assertEqual(
         'coronavirus mutations observed mutations SARS-CoV-2 genome mutations',
         topics[40]['query'])
     # assertIsInstance shows the actual key type when the check fails.
     self.assertIsInstance(next(iter(topics.keys())), int)
예제 #9
0
    def test_fire2012_en(self):
        """Check 'fire2012-en': 50 topics and 50 qrels entries, ``int`` keys."""
        name = 'fire2012-en'

        topics = search.get_topics(name)
        self.assertIsNotNone(topics)
        self.assertEqual(len(topics), 50)
        # assertIsInstance shows the actual key type when the check fails.
        self.assertIsInstance(next(iter(topics.keys())), int)

        qrels = search.get_qrels(name)
        self.assertIsNotNone(qrels)
        self.assertEqual(len(qrels), 50)
        self.assertIsInstance(next(iter(qrels.keys())), int)
예제 #10
0
 def test_covid_round1_udel(self):
     """Check 'covid-round1-udel': 30 topics, the queries of topics 1 and 30, ``int`` keys."""
     topics = search.get_topics('covid-round1-udel')
     self.assertIsNotNone(topics)
     self.assertEqual(len(topics), 30)
     self.assertEqual('coronavirus origin origin COVID-19',
                      topics[1]['query'])
     self.assertEqual(
         'coronavirus remdesivir remdesivir effective treatment COVID-19',
         topics[30]['query'])
     # assertIsInstance shows the actual key type when the check fails.
     self.assertIsInstance(next(iter(topics.keys())), int)
예제 #11
0
    def test_wt10g(self):
        """Check 'wt10g': 100 topics and 100 qrels entries, ``int`` keys."""
        name = 'wt10g'

        topics = search.get_topics(name)
        self.assertIsNotNone(topics)
        self.assertEqual(len(topics), 100)
        # assertIsInstance shows the actual key type when the check fails.
        self.assertIsInstance(next(iter(topics.keys())), int)

        qrels = search.get_qrels(name)
        self.assertIsNotNone(qrels)
        self.assertEqual(len(qrels), 100)
        self.assertIsInstance(next(iter(qrels.keys())), int)
예제 #12
0
    def test_robust04(self):
        """Check 'robust04': 250 topics, 249 qrels entries, ``int`` keys."""
        name = 'robust04'

        topics = search.get_topics(name)
        self.assertIsNotNone(topics)
        self.assertEqual(len(topics), 250)
        # assertIsInstance shows the actual key type when the check fails.
        self.assertIsInstance(next(iter(topics.keys())), int)

        qrels = search.get_qrels(name)
        self.assertIsNotNone(qrels)
        self.assertEqual(len(qrels), 249)
        self.assertIsInstance(next(iter(qrels.keys())), int)
예제 #13
0
    def test_ntcir8_zh(self):
        """Check 'ntcir8-zh': 73 topics, 100 qrels entries, ``str`` keys."""
        name = 'ntcir8-zh'

        topics = search.get_topics(name)
        self.assertIsNotNone(topics)
        self.assertEqual(len(topics), 73)
        # assertIsInstance shows the actual key type when the check fails.
        self.assertIsInstance(next(iter(topics.keys())), str)

        qrels = search.get_qrels(name)
        self.assertIsNotNone(qrels)
        self.assertEqual(len(qrels), 100)
        self.assertIsInstance(next(iter(qrels.keys())), str)
예제 #14
0
    def test_trec3_adhoc(self):
        """Check 'trec3-adhoc': 50 topics and 50 qrels entries, ``int`` keys."""
        name = 'trec3-adhoc'

        topics = search.get_topics(name)
        self.assertIsNotNone(topics)
        self.assertEqual(len(topics), 50)
        # assertIsInstance shows the actual key type when the check fails.
        self.assertIsInstance(next(iter(topics.keys())), int)

        qrels = search.get_qrels(name)
        self.assertIsNotNone(qrels)
        self.assertEqual(len(qrels), 50)
        self.assertIsInstance(next(iter(qrels.keys())), int)
예제 #15
0
    def test_msmarco_v2_passage_dev2(self):
        """Check the 'msmarco-v2-passage-dev2' topic variants and qrels.

        Each of the three topic sets and the qrels are expected to hold
        4281 entries keyed by ``int``.
        """
        topic_sets = (
            'msmarco-v2-passage-dev2',
            'msmarco-v2-passage-dev2-unicoil',
            'msmarco-v2-passage-dev2-unicoil-noexp',
        )
        for name in topic_sets:
            # subTest labels a failure with the variant that triggered it
            # and lets the remaining variants still run.
            with self.subTest(topics=name):
                topics = search.get_topics(name)
                self.assertIsNotNone(topics)
                self.assertEqual(len(topics), 4281)
                self.assertIsInstance(next(iter(topics.keys())), int)

        qrels = search.get_qrels('msmarco-v2-passage-dev2')
        self.assertIsNotNone(qrels)
        self.assertEqual(len(qrels), 4281)
        self.assertIsInstance(next(iter(qrels.keys())), int)
예제 #16
0
    def test_msmarco_v2_doc_dev(self):
        """Check the 'msmarco-v2-doc-dev' topic variants and qrels.

        Each of the three topic sets and the qrels are expected to hold
        4552 entries keyed by ``int``.
        """
        topic_sets = (
            'msmarco-v2-doc-dev',
            'msmarco-v2-doc-dev-unicoil',
            'msmarco-v2-doc-dev-unicoil-noexp',
        )
        for name in topic_sets:
            # subTest labels a failure with the variant that triggered it
            # and lets the remaining variants still run.
            with self.subTest(topics=name):
                topics = search.get_topics(name)
                self.assertIsNotNone(topics)
                self.assertEqual(len(topics), 4552)
                self.assertIsInstance(next(iter(topics.keys())), int)

        qrels = search.get_qrels('msmarco-v2-doc-dev')
        self.assertIsNotNone(qrels)
        self.assertEqual(len(qrels), 4552)
        self.assertIsInstance(next(iter(qrels.keys())), int)
예제 #17
0
    def test_dl19_passage(self):
        """Check the 'dl19-passage' topic variants and qrels.

        Each of the three topic sets and the qrels are expected to hold
        43 entries whose keys are not ``str``.
        """
        topic_sets = (
            'dl19-passage',
            'dl19-passage-unicoil',
            'dl19-passage-unicoil-noexp',
        )
        for name in topic_sets:
            # subTest labels a failure with the variant that triggered it
            # and lets the remaining variants still run.
            with self.subTest(topics=name):
                topics = search.get_topics(name)
                self.assertIsNotNone(topics)
                self.assertEqual(len(topics), 43)
                self.assertNotIsInstance(next(iter(topics.keys())), str)

        qrels = search.get_qrels('dl19-passage')
        self.assertIsNotNone(qrels)
        self.assertEqual(len(qrels), 43)
        self.assertNotIsInstance(next(iter(qrels.keys())), str)
예제 #18
0
    def test_trec_topicreader(self):
        """Read the robust04 topics file with ``TrecTopicReader``.

        The result is expected to hold 250 ``int``-keyed topics and to equal
        what ``search.get_topics('robust04')`` returns.
        """
        # Running from command-line, we're in root of repo, but running in IDE, we're in tests/
        path = 'tools/topics-and-qrels/topics.robust04.txt'
        if not os.path.exists(path):
            path = f'../{path}'

        self.assertTrue(os.path.exists(path), f'topics file not found: {path}')
        topics = search.get_topics_with_reader('io.anserini.search.topicreader.TrecTopicReader', path)
        self.assertEqual(len(topics), 250)
        # assertIsInstance shows the actual key type when the check fails.
        self.assertIsInstance(next(iter(topics.keys())), int)

        self.assertEqual(search.get_topics('robust04'), topics)
예제 #19
0
    def test_covid_round3(self):
        """Check 'covid-round3': 40 topics/qrels, the queries of topics 1 and 40, ``int`` keys."""
        name = 'covid-round3'

        topics = search.get_topics(name)
        self.assertIsNotNone(topics)
        self.assertEqual(len(topics), 40)
        self.assertEqual('coronavirus origin', topics[1]['query'])
        self.assertEqual('coronavirus mutations', topics[40]['query'])
        # assertIsInstance shows the actual key type when the check fails.
        self.assertIsInstance(next(iter(topics.keys())), int)

        qrels = search.get_qrels(name)
        self.assertIsNotNone(qrels)
        self.assertEqual(len(qrels), 40)
        self.assertIsInstance(next(iter(qrels.keys())), int)
예제 #20
0
    def test_tsv_int_topicreader(self):
        """Read the MS MARCO doc dev topics file with ``TsvIntTopicReader``.

        The result is expected to hold 5193 ``int``-keyed topics and to equal
        what ``search.get_topics('msmarco_doc_dev')`` returns.
        """
        # Running from command-line, we're in root of repo, but running in IDE, we're in tests/
        path = 'tools/topics-and-qrels/topics.msmarco-doc.dev.txt'
        if not os.path.exists(path):
            path = f'../{path}'

        self.assertTrue(os.path.exists(path), f'topics file not found: {path}')
        topics = search.get_topics_with_reader('io.anserini.search.topicreader.TsvIntTopicReader', path)
        self.assertEqual(len(topics), 5193)
        # assertIsInstance shows the actual key type when the check fails.
        self.assertIsInstance(next(iter(topics.keys())), int)

        # NOTE(review): this key uses underscores while the file name uses
        # 'msmarco-doc.dev' -- confirm it matches the registered topic name.
        self.assertEqual(search.get_topics('msmarco_doc_dev'), topics)
예제 #21
0
    def test_covid_round2(self):
        """Check 'covid-round2': 35 topics/qrels, the queries of topics 1 and 35, ``int`` keys."""
        name = 'covid-round2'

        topics = search.get_topics(name)
        self.assertIsNotNone(topics)
        self.assertEqual(len(topics), 35)
        self.assertEqual('coronavirus origin', topics[1]['query'])
        self.assertEqual('coronavirus public datasets', topics[35]['query'])
        # assertIsInstance shows the actual key type when the check fails.
        self.assertIsInstance(next(iter(topics.keys())), int)

        qrels = search.get_qrels(name)
        self.assertIsNotNone(qrels)
        self.assertEqual(len(qrels), 35)
        self.assertIsInstance(next(iter(qrels.keys())), int)
예제 #22
0
    def test_covid_round4(self):
        """Check 'covid-round4': 45 topics/qrels, the queries of topics 1 and 45, ``int`` keys."""
        name = 'covid-round4'

        topics = search.get_topics(name)
        self.assertIsNotNone(topics)
        self.assertEqual(len(topics), 45)
        self.assertEqual('coronavirus origin', topics[1]['query'])
        self.assertEqual('coronavirus mental health impact',
                         topics[45]['query'])
        # assertIsInstance shows the actual key type when the check fails.
        self.assertIsInstance(next(iter(topics.keys())), int)

        qrels = search.get_qrels(name)
        self.assertIsNotNone(qrels)
        self.assertEqual(len(qrels), 45)
        self.assertIsInstance(next(iter(qrels.keys())), int)
예제 #23
0
    def test_trec2019_bl(self):
        """Check 'trec2019-bl': 60 topics, 57 qrels entries, two topic titles, ``int`` keys."""
        name = 'trec2019-bl'

        topics = search.get_topics(name)
        self.assertIsNotNone(topics)
        self.assertEqual(len(topics), 60)
        self.assertEqual('d7d906991e2883889f850de9ae06655e',
                         topics[870]['title'])
        self.assertEqual('0d7f5e24cafc019265d3ee4b9745e7ea',
                         topics[829]['title'])
        # assertIsInstance shows the actual key type when the check fails.
        self.assertIsInstance(next(iter(topics.keys())), int)

        qrels = search.get_qrels(name)
        self.assertIsNotNone(qrels)
        self.assertEqual(len(qrels), 57)
        self.assertIsInstance(next(iter(qrels.keys())), int)
예제 #24
0
    def test_trec2018_bl(self):
        """Check 'trec2018-bl': 50 topics, 50 qrels entries, two topic titles, ``int`` keys."""
        name = 'trec2018-bl'

        topics = search.get_topics(name)
        self.assertIsNotNone(topics)
        self.assertEqual(len(topics), 50)
        self.assertEqual('fef0f232a9bd94bdb96bac48c7705503',
                         topics[393]['title'])
        self.assertEqual('a1c41a70-35c7-11e3-8a0e-4e2cf80831fc',
                         topics[825]['title'])
        # assertIsInstance shows the actual key type when the check fails.
        self.assertIsInstance(next(iter(topics.keys())), int)

        qrels = search.get_qrels(name)
        self.assertIsNotNone(qrels)
        self.assertEqual(len(qrels), 50)
        self.assertIsInstance(next(iter(qrels.keys())), int)
예제 #25
0
    def test_dl20(self):
        """Check 'dl20' topics and the 'dl20-doc' / 'dl20-passage' qrels.

        Expects 200 topics, 45 doc qrels entries and 54 passage qrels
        entries, none of them keyed by ``str``.
        """
        topics = search.get_topics('dl20')
        self.assertIsNotNone(topics)
        self.assertEqual(len(topics), 200)
        # assertNotIsInstance reports the offending key when the check fails.
        self.assertNotIsInstance(next(iter(topics.keys())), str)

        doc_qrels = search.get_qrels('dl20-doc')
        self.assertIsNotNone(doc_qrels)
        self.assertEqual(len(doc_qrels), 45)
        self.assertNotIsInstance(next(iter(doc_qrels.keys())), str)

        passage_qrels = search.get_qrels('dl20-passage')
        self.assertIsNotNone(passage_qrels)
        self.assertEqual(len(passage_qrels), 54)
        self.assertNotIsInstance(next(iter(passage_qrels.keys())), str)
예제 #26
0
def main():
    """Run a sanity check against the MS MARCO passage index.

    Searches the index with BM25, inspects the index statistics, loads the
    ``msmarco-passage-dev-subset`` topics and analyzes a sample query.
    Prints ``INSTALLATION OK`` when every check passes; otherwise prints an
    error banner followed by the exception that was raised. Nothing is
    returned and no exception escapes.
    """
    try:
        # Location of the generated index
        index_loc = "indexes/msmarco-passage/lucene-index-msmarco"

        # Create a searcher object and make BM25 the active scorer
        searcher = SimpleSearcher(index_loc)
        searcher.set_bm25(k1=0.9, b=0.4)

        # Fetch 3 results for the test query and check that their docids
        # match the expected ranking
        results = searcher.search('this is a test query', k=3)
        expected = ['5578280', '2016011', '7004677']
        docids = [x.docid for x in results]
        if expected != docids:
            raise Exception('Test query results do not match expected:',
                            expected, '(expected)', docids, '(actual)')

        # IndexReader can give information about the index
        indexer = IndexReader(index_loc)
        if indexer.stats()['total_terms'] != 352316036:
            raise Exception(
                'There are an unexpected number of terms in your index set, perhaps something went wrong while downloading and indexing the dataset?'
            )

        topics = get_topics("msmarco-passage-dev-subset")
        # Any empty result (including None) means the topics were not found.
        if not topics:
            raise Exception(
                'Could not find msmarco-passage-dev-subset... Best approach is to retry indexing the dataset.'
            )
        first_query = next(iter(topics.values()))['title']
        if first_query != "why do people grind teeth in sleep":
            raise Exception(
                'Found a different first query than expected in the dataset. Did you download the right dataset?'
            )

        # Using the pyserini tokenizer/stemmer/etc. to create queries from scratch
        query = "This is a test query in which things are tested. Found using www.google.com of course!"
        # Tokenizing in pyserini is called Analyzing
        output = indexer.analyze(query)
        if len(output) != 9:
            raise Exception(
                'Tokenizer is not working correctly, something is probably wrong in Anserini. Perhaps try to install Anserini again.'
            )
    except Exception as inst:
        # Top-level boundary of the check: report the failure instead of
        # letting a traceback escape.
        print('ERROR: something went wrong in the installation')
        print(inst)
    else:
        print("INSTALLATION OK")
예제 #27
0
 def __init__(
     self,
     k,
     index_loc='../../anserini/indexes/lucene-wapost.v2.pos+docvectors+raw'
 ):
     """Set up a BM25 searcher over ``index_loc`` and load the 'core18' topics.

     Args:
         k: Number of hits to return.
         index_loc: Location of the Lucene index; it must have been
             built beforehand.
     """
     self.utils = Utils()

     # Searcher over the pre-built index, scored with BM25.
     self.index_loc = index_loc
     self.searcher = SimpleSearcher(self.index_loc)
     self.k = k
     self.searcher.set_bm25(k1=0.9, b=0.4)
     # Relevance feedback (RM3) is intentionally not enabled here.

     self.batch_hits = {}

     # Topic ids (as strings) and titles, both taken from the same mapping.
     self.topics = get_topics('core18')
     self.query_ids = list(map(str, self.topics.keys()))
     self.queries = [entry['title'] for entry in self.topics.values()]

     self.doc_ids = {}
     self.scores = {}
예제 #28
0
 def from_topics(cls, topics_path: str):
     """Build an instance from a topics file or a named topic set.

     If ``topics_path`` exists on disk it is parsed by extension: ``.json``
     files with ``json.load`` and ``.tsv`` files with the
     ``TsvIntTopicReader`` topic reader. Otherwise the string is handed to
     ``get_topics``.

     Args:
         topics_path: Path to a ``.json``/``.tsv`` topics file, or an
             identifier accepted by ``get_topics``.

     Returns:
         ``cls(topics, order)``, where ``order`` comes from
         ``QueryIterator.get_predefined_order(topics_path)``.

     Raises:
         NotImplementedError: The file exists but its extension is not
             ``.json`` or ``.tsv``.
         FileNotFoundError: No topics could be loaded.
     """
     if not os.path.exists(topics_path):
         # Not a file on disk: treat the argument as a topic-set name.
         topics = get_topics(topics_path)
     elif topics_path.endswith('.json'):
         # Decode explicitly as UTF-8 so parsing does not depend on the
         # platform's default encoding.
         with open(topics_path, 'r', encoding='utf-8') as f:
             topics = json.load(f)
     elif topics_path.endswith('.tsv'):
         topics = get_topics_with_reader(
             'io.anserini.search.topicreader.TsvIntTopicReader',
             topics_path)
     else:
         raise NotImplementedError(
             f"Not sure how to parse {topics_path}. Please specify the file extension."
         )

     if not topics:
         raise FileNotFoundError(f'Topic {topics_path} Not Found')
     order = QueryIterator.get_predefined_order(topics_path)
     return cls(topics, order)
예제 #29
0
    def test_simple_fusion_searcher(self):
        """Fuse three CORD-19 indexes with RRF and compare to a reference run.

        Runs every 'covid-round2' topic through ``SimpleFusionSearcher``,
        saves the concatenated run to ``runs/fused.txt`` and compares its
        (topic, docid, rank) columns for ranks <= 100 against
        ``runs/anserini.covid-r2.fusion1.txt``.
        """
        index_dirs = [
            'indexes/lucene-index-cord19-abstract-2020-05-01/',
            'indexes/lucene-index-cord19-full-text-2020-05-01/',
            'indexes/lucene-index-cord19-paragraph-2020-05-01/'
        ]

        searcher = SimpleFusionSearcher(index_dirs, method=FusionMethod.RRF)

        runs, topics = [], get_topics('covid-round2')
        for topic in tqdm(sorted(topics.keys())):
            query = topics[topic]['question'] + ' ' + topics[topic]['query']
            hits = searcher.search(query,
                                   k=10000,
                                   query_generator=None,
                                   strip_segment_id=True,
                                   remove_dups=True)
            docid_score_pair = [(hit.docid, hit.score) for hit in hits]
            run = TrecRun.from_search_results(docid_score_pair, topic=topic)
            runs.append(run)

        all_topics_run = TrecRun.concat(runs)
        all_topics_run.save_to_txt(output_path='runs/fused.txt',
                                   tag='reciprocal_rank_fusion_k=60')

        # Only keep topic, docid, and rank. Scores may be slightly different due to floating point precision issues and underlying lib versions.
        # TODO: We should probably do this in Python as opposed to calling out to shell for better portability.
        # This has also proven to be a somewhat brittle test, see https://github.com/castorini/pyserini/issues/947
        # A stopgap for above issue, we're restricting comparison to only top-100 ranks.
        #
        # The exit status is checked because a failed awk still leaves an
        # (empty) output file behind; two such files would compare equal.
        status = os.system(
            """awk '$4 <= 100 {print $1" "$3" "$4}' runs/fused.txt > runs/this.txt"""
        )
        self.assertEqual(0, status, 'awk failed on runs/fused.txt')
        status = os.system(
            """awk '$4 <= 100 {print $1" "$3" "$4}' runs/anserini.covid-r2.fusion1.txt > runs/that.txt"""
        )
        self.assertEqual(0, status,
                         'awk failed on runs/anserini.covid-r2.fusion1.txt')

        # shallow=False forces a byte-by-byte comparison; the default may
        # accept files whose stat signatures merely happen to match.
        self.assertTrue(
            filecmp.cmp('runs/this.txt', 'runs/that.txt', shallow=False))
예제 #30
0
    def test_simple_fusion_searcher(self):
        """Fuse three CORD-19 indexes with RRF and compare to a reference run.

        Runs every 'covid-round2' topic through ``SimpleFusionSearcher``,
        saves the concatenated run to ``runs/fused.txt`` and compares its
        (topic, docid, rank) columns against
        ``runs/anserini.covid-r2.fusion1.txt``.
        """
        index_dirs = ['indexes/lucene-index-cord19-abstract-2020-05-01/',
                      'indexes/lucene-index-cord19-full-text-2020-05-01/',
                      'indexes/lucene-index-cord19-paragraph-2020-05-01/']

        searcher = SimpleFusionSearcher(index_dirs, method=FusionMethod.RRF)

        runs, topics = [], get_topics('covid-round2')
        for topic in tqdm(sorted(topics.keys())):
            query = topics[topic]['question'] + ' ' + topics[topic]['query']
            hits = searcher.search(query, k=10000, query_generator=None, strip_segment_id=True, remove_dups=True)
            docid_score_pair = [(hit.docid, hit.score) for hit in hits]
            run = TrecRun.from_search_results(docid_score_pair, topic=topic)
            runs.append(run)

        all_topics_run = TrecRun.concat(runs)
        all_topics_run.save_to_txt(output_path='runs/fused.txt', tag='reciprocal_rank_fusion_k=60')

        # Only keep topic, docid and rank. Scores have different floating point precisions.
        # TODO: We should probably do this in Python as opposed to calling out to shell for better portability.
        #
        # The exit status is checked because a failed awk still leaves an
        # (empty) output file behind; two such files would compare equal.
        status = os.system("""awk '{print $1" "$3" "$4}' runs/fused.txt > runs/this.txt""")
        self.assertEqual(0, status, 'awk failed on runs/fused.txt')
        status = os.system("""awk '{print $1" "$3" "$4}' runs/anserini.covid-r2.fusion1.txt > runs/that.txt""")
        self.assertEqual(0, status, 'awk failed on runs/anserini.covid-r2.fusion1.txt')

        # shallow=False forces a byte-by-byte comparison; the default may
        # accept files whose stat signatures merely happen to match.
        self.assertTrue(filecmp.cmp('runs/this.txt', 'runs/that.txt', shallow=False))