text = text.join(toks) if args.batch_size <= 1 and args.threads <= 1: hits = searcher.search(text, args.hits, query_generator=query_generator, fields=fields) results = [(topic_id, hits)] else: batch_topic_ids.append(str(topic_id)) batch_topics.append(text) if (index + 1) % args.batch_size == 0 or \ index == len(topics.keys()) - 1: results = searcher.batch_search( batch_topics, batch_topic_ids, args.hits, args.threads, query_generator=query_generator, fields=fields) results = [(id_, results[id_]) for id_ in batch_topic_ids] batch_topic_ids.clear() batch_topics.clear() else: continue for topic, hits in results: # do rerank if use_prcl and len(hits) > (args.r + args.n): docids = [hit.docid.strip() for hit in hits] scores = [hit.score for hit in hits] scores, docids = ranker.rerank(docids, scores)
# Retrieval driver: iterate topics in the canonical order for this query set and
# write ranked hits to the output run file. Two execution paths:
#   * sequential (batch_size <= 1 and threads <= 1): one searcher.search() per topic;
#   * batched: buffer topics, then issue one searcher.batch_search() per full batch.
order = QUERY_IDS[args.topics]
with open(output_path, 'w') as target_file:
    # NOTE(review): target_file is not referenced in this chunk; presumably
    # write_result / write_result_max_passage close over it — confirm upstream.
    batch_topics = list()
    batch_topic_ids = list()
    for index, (topic_id, text) in enumerate(
            tqdm(list(query_iterator(topics, order)))):
        if args.batch_size <= 1 and args.threads <= 1:
            # Sequential path: search immediately and emit a one-element result list.
            hits = searcher.search(text, args.hits)
            results = [(topic_id, hits)]
        else:
            # Batched path: buffer this topic; flush when the batch is full or
            # when this is the last topic overall.
            batch_topic_ids.append(str(topic_id))
            batch_topics.append(text)
            if (index + 1) % args.batch_size == 0 or \
                    index == len(topics.keys()) - 1:
                results = searcher.batch_search(
                    batch_topics, batch_topic_ids, args.hits, args.threads)
                # batch_search returns a dict keyed by topic id; rebuild the
                # (id, hits) pairs in original submission order.
                results = [(id_, results[id_]) for id_ in batch_topic_ids]
                batch_topic_ids.clear()
                batch_topics.clear()
            else:
                # Batch not full yet: nothing to write on this iteration.
                continue
        for result in results:
            if args.max_passage:
                write_result_max_passage(result)
            else:
                write_result(result)
        results.clear()
class TestSearch(unittest.TestCase):
    """Integration tests for SimpleSearcher against the pre-built CACM index.

    Expected docids/scores are pinned regression values for this exact index
    snapshot; they will change if the index or scoring defaults change.
    """

    def setUp(self):
        # Download pre-built CACM index; append a random value to avoid filename clashes.
        r = randint(0, 10000000)
        self.collection_url = 'https://github.com/castorini/anserini-data/raw/master/CACM/lucene-index.cacm.tar.gz'
        self.tarball_name = 'lucene-index.cacm-{}.tar.gz'.format(r)
        self.index_dir = 'index{}/'.format(r)

        filename, headers = urlretrieve(self.collection_url, self.tarball_name)

        tarball = tarfile.open(self.tarball_name)
        tarball.extractall(self.index_dir)
        tarball.close()

        self.searcher = SimpleSearcher(f'{self.index_dir}lucene-index.cacm')

    def test_basic(self):
        """Single-query search: result types, docids, scores, and field access."""
        self.assertTrue(
            self.searcher.get_similarity().toString().startswith('BM25'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(hits, List))

        self.assertTrue(isinstance(hits[0], JSimpleSearcherResult))
        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertEqual(hits[0].lucene_docid, 3133)
        self.assertEqual(len(hits[0].contents), 1500)
        self.assertEqual(len(hits[0].raw), 1532)
        self.assertAlmostEqual(hits[0].score, 4.76550, places=5)

        # Test accessing the raw Lucene document and fetching fields from it:
        self.assertEqual(hits[0].lucene_document.getField('id').stringValue(),
                         'CACM-3134')
        self.assertEqual(hits[0].lucene_document.get('id'),
                         'CACM-3134')  # simpler call, same result as above
        self.assertEqual(
            len(hits[0].lucene_document.getField('raw').stringValue()), 1532)
        self.assertEqual(len(hits[0].lucene_document.get('raw')),
                         1532)  # simpler call, same result as above

        self.assertTrue(isinstance(hits[9], JSimpleSearcherResult))
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 4.21740, places=5)

        hits = self.searcher.search('search')

        self.assertTrue(isinstance(hits[0], JSimpleSearcherResult))
        self.assertEqual(hits[0].docid, 'CACM-3058')
        self.assertAlmostEqual(hits[0].score, 2.85760, places=5)

        self.assertTrue(isinstance(hits[9], JSimpleSearcherResult))
        self.assertEqual(hits[9].docid, 'CACM-3040')
        self.assertAlmostEqual(hits[9].score, 2.68780, places=5)

    def test_batch(self):
        """batch_search returns a dict keyed by qid with per-query hit lists."""
        results = self.searcher.batch_search(
            ['information retrieval', 'search'], ['q1', 'q2'], threads=2)

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(results, Dict))

        self.assertTrue(isinstance(results['q1'], List))
        self.assertTrue(isinstance(results['q1'][0], JSimpleSearcherResult))
        self.assertEqual(results['q1'][0].docid, 'CACM-3134')
        self.assertAlmostEqual(results['q1'][0].score, 4.76550, places=5)

        self.assertTrue(isinstance(results['q1'][9], JSimpleSearcherResult))
        self.assertEqual(results['q1'][9].docid, 'CACM-2516')
        self.assertAlmostEqual(results['q1'][9].score, 4.21740, places=5)

        self.assertTrue(isinstance(results['q2'], List))
        self.assertTrue(isinstance(results['q2'][0], JSimpleSearcherResult))
        self.assertEqual(results['q2'][0].docid, 'CACM-3058')
        self.assertAlmostEqual(results['q2'][0].score, 2.85760, places=5)

        self.assertTrue(isinstance(results['q2'][9], JSimpleSearcherResult))
        self.assertEqual(results['q2'][9].docid, 'CACM-3040')
        self.assertAlmostEqual(results['q2'][9].score, 2.68780, places=5)

    def test_basic_k(self):
        """The k parameter bounds the number of hits returned by search()."""
        hits = self.searcher.search('information retrieval', k=100)

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(hits, List))
        self.assertTrue(isinstance(hits[0], JSimpleSearcherResult))
        self.assertEqual(len(hits), 100)

    def test_batch_k(self):
        """The k parameter bounds the number of hits per query in batch_search()."""
        results = self.searcher.batch_search(
            ['information retrieval', 'search'], ['q1', 'q2'], k=100, threads=2)

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(results, Dict))

        self.assertTrue(isinstance(results['q1'], List))
        self.assertTrue(isinstance(results['q1'][0], JSimpleSearcherResult))
        self.assertEqual(len(results['q1']), 100)

        self.assertTrue(isinstance(results['q2'], List))
        self.assertTrue(isinstance(results['q2'][0], JSimpleSearcherResult))
        self.assertEqual(len(results['q2']), 100)

    def test_basic_fields(self):
        # This test just provides a sanity check, it's not that interesting as it only searches one field.
        hits = self.searcher.search('information retrieval', k=42,
                                    fields={'contents': 2.0})

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(hits, List))
        self.assertTrue(isinstance(hits[0], JSimpleSearcherResult))
        self.assertEqual(len(hits), 42)

    def test_batch_fields(self):
        # This test just provides a sanity check, it's not that interesting as it only searches one field.
        results = self.searcher.batch_search(
            ['information retrieval', 'search'], ['q1', 'q2'], k=42, threads=2,
            fields={'contents': 2.0})

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertTrue(isinstance(results, Dict))

        self.assertTrue(isinstance(results['q1'], List))
        self.assertTrue(isinstance(results['q1'][0], JSimpleSearcherResult))
        self.assertEqual(len(results['q1']), 42)

        self.assertTrue(isinstance(results['q2'], List))
        self.assertTrue(isinstance(results['q2'][0], JSimpleSearcherResult))
        self.assertEqual(len(results['q2']), 42)

    def test_different_similarity(self):
        """Switching scoring models (QLD <-> BM25) changes rankings as expected."""
        # qld, default mu
        self.searcher.set_qld()
        self.assertTrue(self.searcher.get_similarity().toString().startswith(
            'LM Dirichlet'))

        hits = self.searcher.search('information retrieval')
        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 3.68030, places=5)
        self.assertEqual(hits[9].docid, 'CACM-1927')
        self.assertAlmostEqual(hits[9].score, 2.53240, places=5)

        # bm25, default parameters
        self.searcher.set_bm25()
        self.assertTrue(
            self.searcher.get_similarity().toString().startswith('BM25'))

        hits = self.searcher.search('information retrieval')
        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 4.76550, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 4.21740, places=5)

        # qld, custom mu
        self.searcher.set_qld(100)
        self.assertTrue(self.searcher.get_similarity().toString().startswith(
            'LM Dirichlet'))

        hits = self.searcher.search('information retrieval')
        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 6.35580, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2631')
        self.assertAlmostEqual(hits[9].score, 5.18960, places=5)

        # bm25, custom parameters
        self.searcher.set_bm25(0.8, 0.3)
        self.assertTrue(
            self.searcher.get_similarity().toString().startswith('BM25'))

        hits = self.searcher.search('information retrieval')
        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 4.86880, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 4.33320, places=5)

    def test_rm3(self):
        """RM3 relevance feedback can be enabled, disabled, and re-parameterized."""
        self.searcher.set_rm3()
        self.assertTrue(self.searcher.is_using_rm3())

        hits = self.searcher.search('information retrieval')
        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 2.18010, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 1.70330, places=5)

        self.searcher.unset_rm3()
        self.assertFalse(self.searcher.is_using_rm3())

        hits = self.searcher.search('information retrieval')
        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 4.76550, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 4.21740, places=5)

        self.searcher.set_rm3(fb_docs=4, fb_terms=6, original_query_weight=0.3)
        self.assertTrue(self.searcher.is_using_rm3())

        hits = self.searcher.search('information retrieval')
        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 2.17190, places=5)
        self.assertEqual(hits[9].docid, 'CACM-1457')
        self.assertAlmostEqual(hits[9].score, 1.43700, places=5)

    def test_doc_int(self):
        # The doc method is overloaded: if input is int, it's assumed to be a Lucene internal docid.
        doc = self.searcher.doc(1)
        self.assertTrue(isinstance(doc, Document))

        # These are all equivalent ways to get the docid.
        self.assertEqual('CACM-0002', doc.id())
        self.assertEqual('CACM-0002', doc.docid())
        self.assertEqual('CACM-0002', doc.get('id'))
        self.assertEqual('CACM-0002',
                         doc.lucene_document().getField('id').stringValue())

        # These are all equivalent ways to get the 'raw' field
        self.assertEqual(186, len(doc.raw()))
        self.assertEqual(186, len(doc.get('raw')))
        self.assertEqual(186, len(doc.lucene_document().get('raw')))
        self.assertEqual(
            186, len(doc.lucene_document().getField('raw').stringValue()))

        # These are all equivalent ways to get the 'contents' field
        self.assertEqual(154, len(doc.contents()))
        self.assertEqual(154, len(doc.get('contents')))
        self.assertEqual(154, len(doc.lucene_document().get('contents')))
        self.assertEqual(
            154, len(doc.lucene_document().getField('contents').stringValue()))

        # Should return None if we request a docid that doesn't exist
        self.assertTrue(self.searcher.doc(314159) is None)

    def test_doc_str(self):
        # The doc method is overloaded: if input is str, it's assumed to be an external collection docid.
        doc = self.searcher.doc('CACM-0002')
        self.assertTrue(isinstance(doc, Document))

        # These are all equivalent ways to get the docid.
        self.assertEqual(doc.lucene_document().getField('id').stringValue(),
                         'CACM-0002')
        self.assertEqual(doc.id(), 'CACM-0002')
        self.assertEqual(doc.docid(), 'CACM-0002')
        self.assertEqual(doc.get('id'), 'CACM-0002')

        # These are all equivalent ways to get the 'raw' field
        self.assertEqual(186, len(doc.raw()))
        self.assertEqual(186, len(doc.get('raw')))
        self.assertEqual(186, len(doc.lucene_document().get('raw')))
        self.assertEqual(
            186, len(doc.lucene_document().getField('raw').stringValue()))

        # These are all equivalent ways to get the 'contents' field
        self.assertEqual(154, len(doc.contents()))
        self.assertEqual(154, len(doc.get('contents')))
        self.assertEqual(154, len(doc.lucene_document().get('contents')))
        self.assertEqual(
            154, len(doc.lucene_document().getField('contents').stringValue()))

        # Should return None if we request a docid that doesn't exist
        self.assertTrue(self.searcher.doc('foo') is None)

    def test_doc_by_field(self):
        """Looking a document up by an arbitrary field matches a docid lookup."""
        self.assertEqual(
            self.searcher.doc('CACM-3134').docid(),
            self.searcher.doc_by_field('id', 'CACM-3134').docid())

        # Should return None if we request a docid that doesn't exist
        self.assertTrue(self.searcher.doc_by_field('foo', 'bar') is None)

    def tearDown(self):
        # Release the JVM-side searcher and remove the downloaded/extracted index.
        self.searcher.close()
        os.remove(self.tarball_name)
        shutil.rmtree(self.index_dir)
# Load the input question set and wrap each entry in a Query record keyed by id.
with open(args.input) as infile:
    payload = json.load(infile)
queries = {
    entry["id"]: Query(entry["id"], entry["question"], entry["answers"])
    for entry in payload["data"]
}

# Retrieve the top-1000 documents for every query.
ranked_queries = {}
if args.batch_size <= 1 and args.threads <= 1:
    # Sequential path: one search call per question.
    for query in queries.values():
        ranked_queries[query] = searcher.search(query.question, 1000)
else:
    # Batched path: issue one multi-threaded batch_search per chunk of queries,
    # then map the returned qid-keyed dict back onto Query objects.
    for chunk in batch(list(queries.values()), args.batch_size):
        chunk_hits = searcher.batch_search(
            [query.question for query in chunk],
            [query.id for query in chunk],
            1000, args.threads)
        ranked_queries.update(
            {queries[qid]: hit_list for qid, hit_list in chunk_hits.items()})

# Assemble the JSON-serializable output: per-query metadata plus the retrieved
# contexts (docid, score, and the document's 'contents' text).
output_dict = {}
for query, hit_list in ranked_queries.items():
    contexts = []
    for hit in hit_list:
        doc_key = hit.docid.strip()
        passage = json.loads(searcher.doc(doc_key).raw())['contents']
        contexts.append({'docid': doc_key, 'score': hit.score, 'text': passage})
    output_dict[query.id] = {"question": query.question,
                             "answers": query.answers,
                             "contexts": contexts}

with open(args.output, "w+") as outfile:
    json.dump(output_dict, outfile, indent=4)
start_time) / (line_number + 1) print( f'Retrieving query {line_number} ({time_per_query:0.3f} s/query)', flush=True) for rank in range(len(hits)): docno = hits[rank].docid fout.write('{}\t{}\t{}\n'.format(qid, docno, rank + 1)) else: qids = [] queries = [] result_dict = {} for line_number, line in enumerate( open(args.queries, 'r', encoding='utf8')): qid, query = line.strip().split('\t') qids.append(qid) queries.append(query) results = searcher.batch_search(queries, qids, args.hits, args.threads) with open(args.output, 'w') as fout: for qid in qids: hits = results.get(qid) for rank in range(len(hits)): docno = hits[rank].docid fout.write(f'{qid}\t{docno}\t{rank+1}\n') total_time = (time.time() - total_start_time) print(f'Total retrieval time: {total_time:0.3f} s') print('Done!')
class PyseriniRanker(Component):
    """Retrieve the top-n documents for each question via a Pyserini SimpleSearcher.

    Documents are parsed from the hit's raw JSON; only dict-shaped, non-empty
    payloads contribute to the returned batches.
    """

    def __init__(
        self,
        index_folder: str,
        n_threads: int = 1,
        top_n: int = 5,
        text_column_name: str = "contents",
        return_scores: bool = False,
        *args,
        **kwargs,
    ):
        self.searcher = SimpleSearcher(str(expand_path(index_folder)))
        self.n_threads = n_threads
        self.top_n = top_n
        self.text_column_name = text_column_name
        self.return_scores = return_scores

    def __call__(self, questions: List[str]) -> Tuple[List[Any], List[float]]:
        """Return per-question document texts (and scores when return_scores is set)."""
        docs_batch = []
        scores_batch = []
        _doc_ids_batch = []

        def _collect(result):
            # Shared bookkeeping for both the single- and multi-question paths.
            texts, ids, ranks = self._processing_search_result(result)
            docs_batch.append(texts)
            scores_batch.append(ranks)
            _doc_ids_batch.append(ids)

        if len(questions) == 1:
            for single_question in questions:
                _collect(self.searcher.search(single_question, self.top_n))
        else:
            # Ceiling division: number of n_threads-sized slices needed.
            num_slices = -(-len(questions) // self.n_threads)
            for slice_idx in range(num_slices):
                window = questions[slice_idx * self.n_threads:
                                   (slice_idx + 1) * self.n_threads]
                window_qids = list(range(len(window)))
                grouped = self.searcher.batch_search(window, window_qids,
                                                     self.top_n, self.n_threads)
                for qid in window_qids:
                    _collect(grouped.get(qid))
        logger.debug(f"found docs {_doc_ids_batch}")
        return (docs_batch, scores_batch) if self.return_scores else docs_batch

    @staticmethod
    def _processing_search_result(res):
        """Split raw hits into parallel lists of texts, ids, and scores."""
        docs, doc_ids, scores = [], [], []
        for hit in res:
            parsed = json.loads(hit.raw)
            # Skip hits whose payload is empty or not a JSON object.
            if parsed and isinstance(parsed, dict):
                docs.append(parsed.get("contents", ""))
                doc_ids.append(parsed.get("id", ""))
                scores.append(hit.score)
        return docs, doc_ids, scores
class BaselineBM25():
    """BM25 retrieval baseline over a pre-built Anserini/Lucene index for core18 topics."""

    def __init__(
        self,
        k,
        index_loc='../../anserini/indexes/lucene-wapost.v2.pos+docvectors+raw'
    ):
        self.utils = Utils()
        # The Lucene index must already have been produced at this location.
        self.index_loc = index_loc
        self.searcher = SimpleSearcher(self.index_loc)
        self.k = k  # number of hits to return per query
        self.searcher.set_bm25(k1=0.9, b=0.4)  # BM25 params
        # searcher.set_rm3(10, 10, 0.5)  # relevance feedback (disabled)
        self.batch_hits = {}
        self.topics = get_topics('core18')
        self.query_ids = [str(id) for id in self.topics.keys()]
        self.queries = [topic['title'] for topic in self.topics.values()]
        self.doc_ids = {}
        self.scores = {}

    def rank(self):
        """Batch-search every topic title and write a trec-eval-compatible run file."""
        print("Ranking for all topics in progress ...")
        # Each hit carries docid, retrieval score, and document content.
        self.batch_hits = self.searcher.batch_search(self.queries,
                                                     self.query_ids,
                                                     k=self.k)
        # Inspect results for first query:
        # print("Scores for first query:")
        # self.utils.print_top_n_results(self.batch_hits[self.query_ids[0]], 10)

        # Collect per-topic docids and scores in a single pass over the hits.
        doc_ids = {}
        scores = {}
        for topic_id, topic_hits in self.batch_hits.items():
            doc_ids[topic_id] = [hit.docid for hit in topic_hits]
            scores[topic_id] = [hit.score for hit in topic_hits]
        self.doc_ids = doc_ids
        self.scores = scores

        run_name = f"BASELINE-N{self.k}"
        self.utils.write_rankings(self.query_ids, self.doc_ids, self.scores,
                                  run_name)

    # Plain accessors, kept verbatim for interface compatibility with callers.
    def get_topics(self):
        return self.topics

    def get_batch_hits(self):
        return self.batch_hits

    def get_index_loc(self):
        return self.index_loc

    def get_query_ids(self):
        return self.query_ids

    def get_queries(self):
        return self.queries

    def get_doc_ids(self):
        return self.doc_ids

    def get_scores(self):
        return self.scores