# Example 1
                text = text.join(toks)
            # Single-query path: neither batching nor threading was requested.
            if args.batch_size <= 1 and args.threads <= 1:
                hits = searcher.search(text,
                                       args.hits,
                                       query_generator=query_generator,
                                       fields=fields)
                results = [(topic_id, hits)]
            else:
                # Buffer topics until a full batch (or the final topic) is reached.
                batch_topic_ids.append(str(topic_id))
                batch_topics.append(text)
                if (index + 1) % args.batch_size == 0 or \
                        index == len(topics.keys()) - 1:
                    results = searcher.batch_search(
                        batch_topics,
                        batch_topic_ids,
                        args.hits,
                        args.threads,
                        query_generator=query_generator,
                        fields=fields)
                    # batch_search returns a dict keyed by topic id;
                    # restore submission order as (id, hits) pairs.
                    results = [(id_, results[id_]) for id_ in batch_topic_ids]
                    batch_topic_ids.clear()
                    batch_topics.clear()
                else:
                    # Batch not full yet; keep accumulating.
                    continue

            for topic, hits in results:
                # Optionally rerank the hits (only when there are enough of them).
                if use_prcl and len(hits) > (args.r + args.n):
                    docids = [hit.docid.strip() for hit in hits]
                    scores = [hit.score for hit in hits]
                    scores, docids = ranker.rerank(docids, scores)
# Example 2
        order = QUERY_IDS[args.topics]

    with open(output_path, 'w') as target_file:
        batch_topics = list()
        batch_topic_ids = list()
        for index, (topic_id, text) in enumerate(
                tqdm(list(query_iterator(topics, order)))):
            # Single-query path when neither batching nor threading is enabled.
            if args.batch_size <= 1 and args.threads <= 1:
                hits = searcher.search(text, args.hits)
                results = [(topic_id, hits)]
            else:
                # Buffer topics until a full batch (or the final topic) is reached.
                batch_topic_ids.append(str(topic_id))
                batch_topics.append(text)
                if (index + 1) % args.batch_size == 0 or \
                        index == len(topics.keys()) - 1:
                    results = searcher.batch_search(batch_topics,
                                                    batch_topic_ids, args.hits,
                                                    args.threads)
                    # batch_search returns a dict; restore submission order.
                    results = [(id_, results[id_]) for id_ in batch_topic_ids]
                    batch_topic_ids.clear()
                    batch_topics.clear()
                else:
                    # Batch not full yet; keep accumulating.
                    continue

            for result in results:
                if args.max_passage:
                    write_result_max_passage(result)
                else:
                    write_result(result)
            results.clear()
# Example 3
class TestSearch(unittest.TestCase):
    """Integration tests for SimpleSearcher against the pre-built CACM index.

    setUp downloads and extracts the index into a uniquely-named directory;
    tearDown removes both the tarball and the extracted index.
    """

    def setUp(self):
        # Download pre-built CACM index; append a random value to avoid filename clashes.
        r = randint(0, 10000000)
        self.collection_url = 'https://github.com/castorini/anserini-data/raw/master/CACM/lucene-index.cacm.tar.gz'
        self.tarball_name = 'lucene-index.cacm-{}.tar.gz'.format(r)
        self.index_dir = 'index{}/'.format(r)

        filename, headers = urlretrieve(self.collection_url, self.tarball_name)

        tarball = tarfile.open(self.tarball_name)
        tarball.extractall(self.index_dir)
        tarball.close()

        self.searcher = SimpleSearcher(f'{self.index_dir}lucene-index.cacm')

    def test_basic(self):
        """Single-query search: result types, docids, scores, raw Lucene access."""
        self.assertTrue(
            self.searcher.get_similarity().toString().startswith('BM25'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(3204, self.searcher.num_docs)
        # Check against builtin list/dict: isinstance with typing.List/Dict
        # is deprecated, and assertIsInstance gives better failure messages.
        self.assertIsInstance(hits, list)

        self.assertIsInstance(hits[0], JSimpleSearcherResult)
        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertEqual(hits[0].lucene_docid, 3133)
        self.assertEqual(len(hits[0].contents), 1500)
        self.assertEqual(len(hits[0].raw), 1532)
        self.assertAlmostEqual(hits[0].score, 4.76550, places=5)

        # Test accessing the raw Lucene document and fetching fields from it:
        self.assertEqual(hits[0].lucene_document.getField('id').stringValue(),
                         'CACM-3134')
        self.assertEqual(hits[0].lucene_document.get('id'),
                         'CACM-3134')  # simpler call, same result as above
        self.assertEqual(
            len(hits[0].lucene_document.getField('raw').stringValue()), 1532)
        self.assertEqual(len(hits[0].lucene_document.get('raw')),
                         1532)  # simpler call, same result as above

        self.assertIsInstance(hits[9], JSimpleSearcherResult)
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 4.21740, places=5)

        hits = self.searcher.search('search')

        self.assertIsInstance(hits[0], JSimpleSearcherResult)
        self.assertEqual(hits[0].docid, 'CACM-3058')
        self.assertAlmostEqual(hits[0].score, 2.85760, places=5)

        self.assertIsInstance(hits[9], JSimpleSearcherResult)
        self.assertEqual(hits[9].docid, 'CACM-3040')
        self.assertAlmostEqual(hits[9].score, 2.68780, places=5)

    def test_batch(self):
        """Multi-query batch_search returns a dict keyed by the supplied qids."""
        results = self.searcher.batch_search(
            ['information retrieval', 'search'], ['q1', 'q2'], threads=2)

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertIsInstance(results, dict)

        self.assertIsInstance(results['q1'], list)
        self.assertIsInstance(results['q1'][0], JSimpleSearcherResult)
        self.assertEqual(results['q1'][0].docid, 'CACM-3134')
        self.assertAlmostEqual(results['q1'][0].score, 4.76550, places=5)

        self.assertIsInstance(results['q1'][9], JSimpleSearcherResult)
        self.assertEqual(results['q1'][9].docid, 'CACM-2516')
        self.assertAlmostEqual(results['q1'][9].score, 4.21740, places=5)

        self.assertIsInstance(results['q2'], list)
        self.assertIsInstance(results['q2'][0], JSimpleSearcherResult)
        self.assertEqual(results['q2'][0].docid, 'CACM-3058')
        self.assertAlmostEqual(results['q2'][0].score, 2.85760, places=5)

        self.assertIsInstance(results['q2'][9], JSimpleSearcherResult)
        self.assertEqual(results['q2'][9].docid, 'CACM-3040')
        self.assertAlmostEqual(results['q2'][9].score, 2.68780, places=5)

    def test_basic_k(self):
        """The k parameter controls the number of hits returned by search."""
        hits = self.searcher.search('information retrieval', k=100)

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertIsInstance(hits, list)
        self.assertIsInstance(hits[0], JSimpleSearcherResult)
        self.assertEqual(len(hits), 100)

    def test_batch_k(self):
        """The k parameter controls the number of hits per query in batch_search."""
        results = self.searcher.batch_search(
            ['information retrieval', 'search'], ['q1', 'q2'],
            k=100,
            threads=2)

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertIsInstance(results, dict)
        self.assertIsInstance(results['q1'], list)
        self.assertIsInstance(results['q1'][0], JSimpleSearcherResult)
        self.assertEqual(len(results['q1']), 100)
        self.assertIsInstance(results['q2'], list)
        self.assertIsInstance(results['q2'][0], JSimpleSearcherResult)
        self.assertEqual(len(results['q2']), 100)

    def test_basic_fields(self):
        # This test just provides a sanity check, it's not that interesting as it only searches one field.
        hits = self.searcher.search('information retrieval',
                                    k=42,
                                    fields={'contents': 2.0})

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertIsInstance(hits, list)
        self.assertIsInstance(hits[0], JSimpleSearcherResult)
        self.assertEqual(len(hits), 42)

    def test_batch_fields(self):
        # This test just provides a sanity check, it's not that interesting as it only searches one field.
        results = self.searcher.batch_search(
            ['information retrieval', 'search'], ['q1', 'q2'],
            k=42,
            threads=2,
            fields={'contents': 2.0})

        self.assertEqual(3204, self.searcher.num_docs)
        self.assertIsInstance(results, dict)
        self.assertIsInstance(results['q1'], list)
        self.assertIsInstance(results['q1'][0], JSimpleSearcherResult)
        self.assertEqual(len(results['q1']), 42)
        self.assertIsInstance(results['q2'], list)
        self.assertIsInstance(results['q2'][0], JSimpleSearcherResult)
        self.assertEqual(len(results['q2']), 42)

    def test_different_similarity(self):
        """Switching between QLD and BM25 (default and custom params) changes scores."""
        # qld, default mu
        self.searcher.set_qld()
        self.assertTrue(self.searcher.get_similarity().toString().startswith(
            'LM Dirichlet'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 3.68030, places=5)
        self.assertEqual(hits[9].docid, 'CACM-1927')
        self.assertAlmostEqual(hits[9].score, 2.53240, places=5)

        # bm25, default parameters
        self.searcher.set_bm25()
        self.assertTrue(
            self.searcher.get_similarity().toString().startswith('BM25'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 4.76550, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 4.21740, places=5)

        # qld, custom mu
        self.searcher.set_qld(100)
        self.assertTrue(self.searcher.get_similarity().toString().startswith(
            'LM Dirichlet'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 6.35580, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2631')
        self.assertAlmostEqual(hits[9].score, 5.18960, places=5)

        # bm25, custom parameters
        self.searcher.set_bm25(0.8, 0.3)
        self.assertTrue(
            self.searcher.get_similarity().toString().startswith('BM25'))

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 4.86880, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 4.33320, places=5)

    def test_rm3(self):
        """RM3 relevance feedback can be toggled and reconfigured."""
        self.searcher.set_rm3()
        self.assertTrue(self.searcher.is_using_rm3())

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 2.18010, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 1.70330, places=5)

        self.searcher.unset_rm3()
        self.assertFalse(self.searcher.is_using_rm3())

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 4.76550, places=5)
        self.assertEqual(hits[9].docid, 'CACM-2516')
        self.assertAlmostEqual(hits[9].score, 4.21740, places=5)

        self.searcher.set_rm3(fb_docs=4, fb_terms=6, original_query_weight=0.3)
        self.assertTrue(self.searcher.is_using_rm3())

        hits = self.searcher.search('information retrieval')

        self.assertEqual(hits[0].docid, 'CACM-3134')
        self.assertAlmostEqual(hits[0].score, 2.17190, places=5)
        self.assertEqual(hits[9].docid, 'CACM-1457')
        self.assertAlmostEqual(hits[9].score, 1.43700, places=5)

    def test_doc_int(self):
        # The doc method is overloaded: if input is int, it's assumed to be a Lucene internal docid.
        doc = self.searcher.doc(1)
        self.assertIsInstance(doc, Document)

        # These are all equivalent ways to get the docid.
        self.assertEqual('CACM-0002', doc.id())
        self.assertEqual('CACM-0002', doc.docid())
        self.assertEqual('CACM-0002', doc.get('id'))
        self.assertEqual('CACM-0002',
                         doc.lucene_document().getField('id').stringValue())

        # These are all equivalent ways to get the 'raw' field
        self.assertEqual(186, len(doc.raw()))
        self.assertEqual(186, len(doc.get('raw')))
        self.assertEqual(186, len(doc.lucene_document().get('raw')))
        self.assertEqual(
            186, len(doc.lucene_document().getField('raw').stringValue()))

        # These are all equivalent ways to get the 'contents' field
        self.assertEqual(154, len(doc.contents()))
        self.assertEqual(154, len(doc.get('contents')))
        self.assertEqual(154, len(doc.lucene_document().get('contents')))
        self.assertEqual(
            154, len(doc.lucene_document().getField('contents').stringValue()))

        # Should return None if we request a docid that doesn't exist
        self.assertIsNone(self.searcher.doc(314159))

    def test_doc_str(self):
        # The doc method is overloaded: if input is str, it's assumed to be an external collection docid.
        doc = self.searcher.doc('CACM-0002')
        self.assertIsInstance(doc, Document)

        # These are all equivalent ways to get the docid.
        self.assertEqual(doc.lucene_document().getField('id').stringValue(),
                         'CACM-0002')
        self.assertEqual(doc.id(), 'CACM-0002')
        self.assertEqual(doc.docid(), 'CACM-0002')
        self.assertEqual(doc.get('id'), 'CACM-0002')

        # These are all equivalent ways to get the 'raw' field
        self.assertEqual(186, len(doc.raw()))
        self.assertEqual(186, len(doc.get('raw')))
        self.assertEqual(186, len(doc.lucene_document().get('raw')))
        self.assertEqual(
            186, len(doc.lucene_document().getField('raw').stringValue()))

        # These are all equivalent ways to get the 'contents' field
        self.assertEqual(154, len(doc.contents()))
        self.assertEqual(154, len(doc.get('contents')))
        self.assertEqual(154, len(doc.lucene_document().get('contents')))
        self.assertEqual(
            154, len(doc.lucene_document().getField('contents').stringValue()))

        # Should return None if we request a docid that doesn't exist
        self.assertIsNone(self.searcher.doc('foo'))

    def test_doc_by_field(self):
        """Looking up by the 'id' field matches a direct docid lookup."""
        self.assertEqual(
            self.searcher.doc('CACM-3134').docid(),
            self.searcher.doc_by_field('id', 'CACM-3134').docid())

        # Should return None if we request a docid that doesn't exist
        self.assertIsNone(self.searcher.doc_by_field('foo', 'bar'))

    def tearDown(self):
        self.searcher.close()
        os.remove(self.tarball_name)
        shutil.rmtree(self.index_dir)
# Example 4
	queries = {}
	with open(args.input) as f:
		data = json.load(f)
		for query in data["data"]:
			queries[query["id"]] = Query(query["id"], query["question"], query["answers"])

	# Find top documents for each query
	ranked_queries = {}

	if args.batch_size <= 1 and args.threads <= 1:
		# Sequential path: one search call per query.
		for qid, q in queries.items():
			hits = searcher.search(q.question, 1000)
			ranked_queries[q] = hits
	else:
		# Batched path: search args.batch_size queries at a time.
		for qs in batch(list(queries.values()), args.batch_size):
			hits = searcher.batch_search([q.question for q in qs], [q.id for q in qs], 1000, args.threads)
			# Re-key results from query id back to the Query objects.
			hits = {queries[q]: v for q, v in hits.items()}
			ranked_queries.update(hits)

	output_dict = {}
	for q, hits in ranked_queries.items():
		output_dict[q.id] = {"question": q.question, "answers": q.answers, "contexts": []}
		for hit in hits:
			docid = hit.docid.strip()
			# Pull the passage text out of the stored raw JSON document.
			ctx = json.loads(searcher.doc(docid).raw())['contents']
			out = {'docid': docid, 'score': hit.score, 'text': ctx}
			output_dict[q.id]["contexts"].append(out)

	with open(args.output, "w+") as f:
		json.dump(output_dict, f, indent=4)
# Example 5
                                      start_time) / (line_number + 1)
                    print(
                        f'Retrieving query {line_number} ({time_per_query:0.3f} s/query)',
                        flush=True)
                # Emit one line per hit: qid, docno, rank (1-based).
                for rank in range(len(hits)):
                    docno = hits[rank].docid
                    fout.write('{}\t{}\t{}\n'.format(qid, docno, rank + 1))
    else:
        # Batched retrieval path: read all queries first, then search at once.
        qids = []
        queries = []
        result_dict = {}

        for line_number, line in enumerate(
                open(args.queries, 'r', encoding='utf8')):
            # Each input line is a tab-separated (qid, query) pair.
            qid, query = line.strip().split('\t')
            qids.append(qid)
            queries.append(query)

        results = searcher.batch_search(queries, qids, args.hits, args.threads)

        with open(args.output, 'w') as fout:
            # Write results in the original query order, one line per hit.
            for qid in qids:
                hits = results.get(qid)
                for rank in range(len(hits)):
                    docno = hits[rank].docid
                    fout.write(f'{qid}\t{docno}\t{rank+1}\n')

    total_time = (time.time() - total_start_time)
    print(f'Total retrieval time: {total_time:0.3f} s')
    print('Done!')
# Example 6
class PyseriniRanker(Component):
    """Ranker backed by a Pyserini SimpleSearcher over a Lucene index.

    For each input question it retrieves the ``top_n`` highest-scoring
    documents; ``return_scores`` additionally returns the retrieval scores.
    """

    def __init__(
        self,
        index_folder: str,
        n_threads: int = 1,
        top_n: int = 5,
        text_column_name: str = "contents",
        return_scores: bool = False,
        *args,
        **kwargs,
    ):
        self.n_threads = n_threads
        self.top_n = top_n
        self.text_column_name = text_column_name
        self.return_scores = return_scores
        self.searcher = SimpleSearcher(str(expand_path(index_folder)))

    def __call__(self, questions: List[str]) -> Tuple[List[Any], List[float]]:
        """Search the index for every question and collect docs/ids/scores."""
        docs_batch: List[Any] = []
        scores_batch: List[Any] = []
        _doc_ids_batch: List[Any] = []

        if len(questions) == 1:
            # Single question: plain search, no threading overhead.
            for query in questions:
                hits = self.searcher.search(query, self.top_n)
                doc_texts, doc_ids, doc_scores = self._processing_search_result(hits)
                docs_batch.append(doc_texts)
                scores_batch.append(doc_scores)
                _doc_ids_batch.append(doc_ids)
        else:
            # Chunk the questions into groups of n_threads and batch-search each.
            total = len(questions)
            num_chunks = total // self.n_threads + int(total % self.n_threads > 0)
            for chunk_idx in range(num_chunks):
                lo = chunk_idx * self.n_threads
                chunk = questions[lo:lo + self.n_threads]
                chunk_qids = list(range(len(chunk)))
                hits_by_qid = self.searcher.batch_search(chunk, chunk_qids,
                                                         self.top_n,
                                                         self.n_threads)
                for q in chunk_qids:
                    doc_texts, doc_ids, doc_scores = self._processing_search_result(
                        hits_by_qid.get(q))
                    docs_batch.append(doc_texts)
                    scores_batch.append(doc_scores)
                    _doc_ids_batch.append(doc_ids)

        logger.debug(f"found docs {_doc_ids_batch}")

        return (docs_batch, scores_batch) if self.return_scores else docs_batch

    @staticmethod
    def _processing_search_result(res):
        """Parse raw JSON hits into parallel lists of contents, ids, scores."""
        docs: List[Any] = []
        doc_ids: List[Any] = []
        scores: List[Any] = []
        for hit in res:
            parsed = json.loads(hit.raw)
            # Skip hits whose raw payload is not a non-empty JSON object.
            if parsed and isinstance(parsed, dict):
                docs.append(parsed.get("contents", ""))
                doc_ids.append(parsed.get("id", ""))
                scores.append(hit.score)

        return docs, doc_ids, scores
# Example 7
class BaselineBM25():
    """BM25 baseline run over a pre-built Lucene index for the core18 topics."""

    def __init__(
        self,
        k,
        index_loc='../../anserini/indexes/lucene-wapost.v2.pos+docvectors+raw'
    ):
        self.utils = Utils()
        # The Lucene index must already exist at this location.
        self.index_loc = index_loc
        self.searcher = SimpleSearcher(self.index_loc)
        self.k = k  # number of hits to return per topic
        self.searcher.set_bm25(k1=0.9, b=0.4)  # BM25 params
        #searcher.set_rm3(10, 10, 0.5)  # relevance feedback
        self.batch_hits = {}
        self.topics = get_topics('core18')
        self.query_ids = [str(id) for id in self.topics.keys()]
        self.queries = [topic['title'] for topic in self.topics.values()]
        self.doc_ids = {}
        self.scores = {}

    def rank(self):
        """Batch-search every topic and write a trec-eval-compatible run file."""
        print("Ranking for all topics in progress ...")
        # Perform batch search on all queries; each hit carries a docid,
        # a retrieval score, and the document content.
        self.batch_hits = self.searcher.batch_search(self.queries,
                                                     self.query_ids,
                                                     k=self.k)

        # Inspect results for first query
        #print("Scores for first query:")
        #self.utils.print_top_n_results(self.batch_hits[self.query_ids[0]], 10)

        # Collect per-query docids and scores for the trec-eval run file.
        doc_ids = {}
        scores = {}
        for query_id, hits in self.batch_hits.items():
            doc_ids[query_id] = [hit.docid for hit in hits]
            scores[query_id] = [hit.score for hit in hits]
        self.doc_ids = doc_ids
        self.scores = scores

        run_name = f"BASELINE-N{self.k}"
        self.utils.write_rankings(self.query_ids, self.doc_ids, self.scores,
                                  run_name)

    def get_topics(self):
        """Return the core18 topics dict."""
        return self.topics

    def get_batch_hits(self):
        """Return the raw batch_search results (dict keyed by query id)."""
        return self.batch_hits

    def get_index_loc(self):
        """Return the path to the Lucene index."""
        return self.index_loc

    def get_query_ids(self):
        """Return the topic ids as strings."""
        return self.query_ids

    def get_queries(self):
        """Return the topic title strings."""
        return self.queries

    def get_doc_ids(self):
        """Return the per-query ranked docid lists."""
        return self.doc_ids

    def get_scores(self):
        """Return the per-query retrieval score lists."""
        return self.scores