コード例 #1
0
 def test_covid_round1(self):
     """Check size, sample queries and key type of the covid_round1_udel topics."""
     topics = pysearch.get_topics('covid_round1_udel')
     self.assertEqual(len(topics), 30)
     self.assertEqual('coronavirus origin origin COVID-19',
                      topics[1]['query'])
     self.assertEqual(
         'coronavirus remdesivir remdesivir effective treatment COVID-19',
         topics[30]['query'])
     # Topic ids are expected to be ints; assertIsInstance reports the actual
     # type on failure, unlike assertTrue(isinstance(...)).
     self.assertIsInstance(next(iter(topics)), int)
コード例 #2
0
 def test_covid_round2(self):
     """Check size, sample queries and key type of the covid_round2_udel topics."""
     topics = pysearch.get_topics('covid_round2_udel')
     self.assertEqual(len(topics), 35)
     self.assertEqual('coronavirus origin origin COVID-19',
                      topics[1]['query'])
     self.assertEqual(
         'coronavirus public datasets public datasets COVID-19',
         topics[35]['query'])
     # Topic ids are expected to be ints; assertIsInstance reports the actual
     # type on failure, unlike assertTrue(isinstance(...)).
     self.assertIsInstance(next(iter(topics)), int)
コード例 #3
0
 def test_msmarco_passage(self):
     """Check size and key type of the msmarco_passage_dev_subset topics."""
     topics = pysearch.get_topics('msmarco_passage_dev_subset')
     self.assertEqual(len(topics), 6980)
     # assertIsInstance reports the actual key type on failure.
     self.assertIsInstance(next(iter(topics)), int)
コード例 #4
0
 def test_covid_round1(self):
     """Check size, sample queries and key type of the covid_round1 topics."""
     topics = pysearch.get_topics('covid_round1')
     self.assertEqual(len(topics), 30)
     self.assertEqual('coronavirus origin', topics[1]['query'])
     self.assertEqual('coronavirus remdesivir', topics[30]['query'])
     # assertIsInstance reports the actual key type on failure.
     self.assertIsInstance(next(iter(topics)), int)
コード例 #5
0
 def test_car20(self):
     """Check size and key type of the car17v2.0_benchmarkY1test topics."""
     topics = pysearch.get_topics('car17v2.0_benchmarkY1test')
     self.assertEqual(len(topics), 2254)
     # Unlike the other collections, these topic ids must NOT be ints;
     # assertNotIsInstance reports the actual type on failure.
     self.assertNotIsInstance(next(iter(topics)), int)
コード例 #6
0
 def test_msmarco_doc(self):
     """Check size and key type of the msmarco_doc_dev topics."""
     topics = pysearch.get_topics('msmarco_doc_dev')
     self.assertEqual(len(topics), 5193)
     # assertIsInstance reports the actual key type on failure.
     self.assertIsInstance(next(iter(topics)), int)
コード例 #7
0
 def test_robust04(self):
     """Check size and key type of the robust04 topics."""
     topics = pysearch.get_topics('robust04')
     self.assertEqual(len(topics), 250)
     # assertIsInstance reports the actual key type on failure.
     self.assertIsInstance(next(iter(topics)), int)
コード例 #8
0
 def test_robust04(self):
     """The robust04 topic set should load with exactly 250 entries."""
     expected_count = 250
     loaded = pysearch.get_topics('robust04')
     self.assertEqual(len(loaded), expected_count)
コード例 #9
0
 def test_msmarco_doc(self):
     """The msmarco_doc_dev topic set should load with exactly 5193 entries."""
     doc_topics = pysearch.get_topics('msmarco_doc_dev')
     n_topics = len(doc_topics)
     self.assertEqual(n_topics, 5193)
コード例 #10
0
 def test_msmarco_passage(self):
     """The msmarco_passage_dev_subset topic set should load with 6980 entries."""
     passage_topics = pysearch.get_topics('msmarco_passage_dev_subset')
     topic_count = len(passage_topics)
     self.assertEqual(topic_count, 6980)
コード例 #11
0
 def test_car20(self):
     """The car17v2.0_benchmarkY1test topic set should load with 2254 entries."""
     collection = 'car17v2.0_benchmarkY1test'
     car_topics = pysearch.get_topics(collection)
     self.assertEqual(len(car_topics), 2254)
コード例 #12
0
 def test_car15(self):
     """The car17v1.5_benchmarkY1test topic set should load with 2125 entries."""
     collection = 'car17v1.5_benchmarkY1test'
     car_topics = pysearch.get_topics(collection)
     self.assertEqual(len(car_topics), 2125)
コード例 #13
0
 def test_core18(self):
     """The core18 topic set should load with exactly 50 entries."""
     core_topics = pysearch.get_topics('core18')
     size = len(core_topics)
     self.assertEqual(size, 50)
コード例 #14
0
 def __post_init__(self):
     """Post-initialisation hook: load the topics named by ``self.collection_name``."""
     name = self.collection_name
     self.topics = pysearch.get_topics(name)
コード例 #15
0
 def test_core18(self):
     """Check size and key type of the core18 topics."""
     topics = pysearch.get_topics('core18')
     self.assertEqual(len(topics), 50)
     # assertIsInstance reports the actual key type on failure.
     self.assertIsInstance(next(iter(topics)), int)
コード例 #16
0
import numpy as np


def make_run_file(file, topics, searcher, w_bm25, w_rnp,
                  probs_file='trueProbs_d2v.npy', k=10, progress_every=100):
    """Write a TREC-format run file that re-scores search hits with a real-news probability.

    For every hit the written score is
    ``w_bm25 * hit.score + w_rnp * prob[str(hit.docid)]``, where ``prob`` is the
    dict stored (pickled) in ``probs_file``.

    Args:
        file: Path of the run file to write (overwritten).
        topics: Mapping of topic id -> topic dict; each topic needs a ``'title'`` key.
        searcher: Object with ``search(query, k)`` returning hits that expose
            ``.docid`` and ``.score``.
        w_bm25: Weight applied to the searcher's score.
        w_rnp: Weight applied to the real-news probability.
        probs_file: ``.npy`` file holding a pickled ``{docid_str: probability}`` dict.
        k: Number of hits requested per query.
        progress_every: Print a progress line after every this many queries.

    Raises:
        KeyError: If a retrieved docid has no entry in the probability dict.
    """
    # NOTE(security): allow_pickle=True unpickles arbitrary objects; only load
    # trusted files here.
    prob_true = np.load(probs_file, allow_pickle=True).item()
    with open(file, 'w') as runfile:
        print('Running {} queries in total'.format(len(topics)))
        for n_done, topic_id in enumerate(topics, start=1):
            # NOTE(review): the query is passed as UTF-8 bytes, as before; confirm
            # this matches the installed pyserini SimpleSearcher.search signature.
            query = topics[topic_id]['title'].encode('utf-8')
            hits = searcher.search(query, k)
            for rank, hit in enumerate(hits, start=1):
                real_news_prob = prob_true[str(hit.docid)]
                score = w_bm25 * hit.score + w_rnp * real_news_prob
                runfile.write('{} Q0 {} {} {:.6f} Anserini\n'.format(
                    topic_id, hit.docid, rank, score))
            # Count finished queries (not written hits) so the message is accurate.
            if n_done % progress_every == 0:
                print(f'{n_done} queries completed')

if __name__ == "__main__":
    # pysearch is not imported elsewhere in this script; import it here so the
    # entry point does not fail with NameError.
    from pyserini.search import pysearch

    topics = pysearch.get_topics('robust04')
    searcher = pysearch.SimpleSearcher('robust_index')
    make_run_file('run.fnc-reranker.txt', topics, searcher, 0.5, 0.5)
コード例 #17
0
from pprint import pprint

import pandas as pd
from elasticsearch import Elasticsearch
from pyserini.search import pysearch
from tqdm.autonotebook import tqdm

from es import ElasticSearcher

# Retrieve documents for every topic of the configured collection:
# - embed each topic's query text,
# - search Elasticsearch with the embedding,
# - collect the results into one DataFrame tagged with the topic id.
# NOTE(review): `text2art` and `Embedder` are not imported in this script;
# presumably they come from the `art` package and a project module — add the
# imports, otherwise the script fails with NameError.
print(text2art('COVID-19 Browser'))
collection_name = 'covid_round3_udel'
embedder = Embedder()
es_searcher = ElasticSearcher(index_name='lucene-index-cord19-2020-05-19-bm25',
                              size=1000)
topics = pysearch.get_topics(collection_name)

bar = tqdm(topics)

dfs = []

for topicid in bar:
    query = topics[topicid]['query']
    bar.set_description(f"[{topicid}]{query[:40]}...")
    # Embed a one-element batch and take its single row as a plain list.
    query_emb = embedder([query])[0].tolist()
    res = es_searcher(query_emb, topic=topicid)
    df = pd.DataFrame.from_records(res.json)
    df['topicid'] = topicid
    dfs.append(df)

# pd.concat raises ValueError on an empty list; fall back to an empty frame
# when the collection yields no topics.
final_df = pd.concat(dfs) if dfs else pd.DataFrame()