def test_faiss_pq(self):
    # Build a FAISS product-quantization (PQ) index over the pre-encoded test
    # corpus via the `pyserini.index.faiss` CLI, then verify the on-disk
    # artifacts and spot-check reconstructed vector components.
    cache_dir = get_cache_home()
    index_dir = f'{cache_dir}/temp_pq'
    encoded_corpus_dir = self.prepare_encoded_collection()
    # NOTE(review): --pq-nbits 128 is unusually large for PQ (8 bits per
    # sub-quantizer is typical) and --efC is an HNSW parameter — confirm the
    # CLI actually honors both flags for a pure --pq index.
    cmd = f'python -m pyserini.index.faiss \
              --input {encoded_corpus_dir} \
              --output {index_dir} \
              --pq-m 3 \
              --efC 1 \
              --pq-nbits 128 \
              --pq'
    status = os.system(cmd)
    self.assertEqual(status, 0)  # shell exit code 0 == indexing succeeded
    docid_fn = os.path.join(index_dir, 'docid')
    index_fn = os.path.join(index_dir, 'index')
    self.assertIsFile(docid_fn)
    self.assertIsFile(index_fn)
    index = faiss.read_index(index_fn)
    # Reconstruct every stored vector; PQ reconstruction is lossy.
    vectors = index.reconstruct_n(0, index.ntotal)
    with open(docid_fn) as f:
        self.assertListEqual([docid.strip() for docid in f], self.docids)
    # NOTE(review): docs 0 and 2 assert identical values — presumably the PQ
    # codebook maps both onto the same codes for this tiny corpus; confirm.
    self.assertAlmostEqual(vectors[0][0], 0.04343192, places=4)
    self.assertAlmostEqual(vectors[0][-1], 0.075478144, places=4)
    self.assertAlmostEqual(vectors[2][0], 0.04343192, places=4)
    self.assertAlmostEqual(vectors[2][-1], 0.075478144, places=4)
def test_faiss_hnsw(self):
    # Build a FAISS HNSW index over the pre-encoded test corpus via the
    # `pyserini.index.faiss` CLI, then verify the on-disk artifacts and
    # spot-check reconstructed vector components (HNSW stores vectors
    # losslessly, so exact encoder outputs are expected).
    cache_dir = get_cache_home()
    index_dir = f'{cache_dir}/temp_hnsw'
    encoded_corpus_dir = self.prepare_encoded_collection()
    cmd = f'python -m pyserini.index.faiss \
              --input {encoded_corpus_dir} \
              --output {index_dir} \
              --M 3 \
              --hnsw'
    status = os.system(cmd)
    self.assertEqual(status, 0)  # shell exit code 0 == indexing succeeded
    docid_fn = os.path.join(index_dir, 'docid')
    index_fn = os.path.join(index_dir, 'index')
    self.assertIsFile(docid_fn)
    self.assertIsFile(index_fn)
    index = faiss.read_index(index_fn)
    vectors = index.reconstruct_n(0, index.ntotal)
    with open(docid_fn) as f:
        self.assertListEqual([docid.strip() for docid in f], self.docids)
    # First/last components of document vectors 0 and 2, matching the values
    # asserted in the plain encoder tests in this file.
    self.assertAlmostEqual(vectors[0][0], 0.12679848074913025, places=4)
    self.assertAlmostEqual(vectors[0][-1], -0.0037349488120526075, places=4)
    self.assertAlmostEqual(vectors[2][0], 0.03678430616855621, places=4)
    self.assertAlmostEqual(vectors[2][-1], 0.13209162652492523, places=4)
def __init__(self, model: str, ibm_model: str, index: str, data: str, prebuilt: bool):
    """Initialize searchers, index readers and the LTR feature extractor.

    Parameters
    ----------
    model : str
        Path/name of the ranking model.
    ibm_model : str
        Path/name of the IBM translation model.
    index : str
        Prebuilt index name (when ``prebuilt``) or a local index path.
    data : str
        Collection flavor; ``'passage'`` selects the passage LTR index,
        anything else the doc-per-passage LTR index.
    prebuilt : bool
        Whether ``index`` names a prebuilt index to be resolved from the cache.
    """
    self.model = model
    self.ibm_model = ibm_model
    if prebuilt:
        self.lucene_searcher = LuceneSearcher.from_prebuilt_index(index)
        cache_index_dir = os.path.join(get_cache_home(), 'indexes')
        # Cached prebuilt directory names include the release date and md5.
        prebuilt_name = (
            'index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3'
            if data == 'passage'
            else 'index-msmarco-doc-per-passage-ltr-20211031-33e4151.bd60e89041b4ebbabc4bf0cfac608a87'
        )
        index_path = os.path.join(cache_index_dir, prebuilt_name)
        self.index_reader = IndexReader.from_prebuilt_index(index)
    else:
        index_path = index
        self.index_reader = IndexReader(index)
    # Use half the cores (at least one) for feature extraction workers.
    worker_count = max(multiprocessing.cpu_count() // 2, 1)
    self.fe = FeatureExtractor(index_path, worker_count)
    self.data = data
def download_kilt_topics(cls, task: str, force=False):
    """Download the KILT query set for ``task`` into the cache.

    Tries each known mirror URL in order and returns the local path of the
    first successful (md5-verified) download.

    Raises
    ------
    ValueError
        If ``task`` is unknown, or every mirror fails.
    """
    if task not in KILT_QUERY_INFO:
        raise ValueError(f'Unrecognized query name {task}')
    info = KILT_QUERY_INFO[task]
    checksum = info['md5']
    save_dir = os.path.join(get_cache_home(), 'queries')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    # First mirror that downloads (and verifies) wins.
    for url in info['urls']:
        try:
            return download_url(url, save_dir, force=force, md5=checksum)
        except (HTTPError, URLError):
            print(
                f'Unable to download encoded query at {url}, trying next URL...'
            )
    raise ValueError(
        f'Unable to download encoded query at any known URLs.')
def prepare_encoded_collection(self):
    # Encode the tiny test corpus with the TCT-ColBERT v2 encoder into a
    # FAISS-ready directory and return that directory's path. Used as a
    # fixture by the faiss indexing tests.
    cache_dir = get_cache_home()
    encoded_corpus_dir = f'{cache_dir}/temp_index'
    cmd = f'python -m pyserini.encode \
              input --corpus {self.test_file} \
                    --fields text \
              output --embeddings {encoded_corpus_dir} \
                     --to-faiss \
              encoder --encoder castorini/tct_colbert-v2-hnp-msmarco \
                      --fields text \
                      --batch 1 \
                      --device cpu'
    status = os.system(cmd)
    self.assertEqual(status, 0)  # shell exit code 0 == encoding succeeded
    # --to-faiss output layout: a 'docid' listing plus a FAISS 'index' file.
    self.assertIsFile(os.path.join(encoded_corpus_dir, 'docid'))
    self.assertIsFile(os.path.join(encoded_corpus_dir, 'index'))
    return encoded_corpus_dir
def test_tct_colbert_v2_encoder_cmd_shard(self):
    # Encode the test corpus in two shards, merge them with
    # `pyserini.index.merge_faiss_indexes`, and check that the merged index
    # matches the unsharded expectations.
    cache_dir = get_cache_home()
    for shard_i in range(2):
        index_dir = f'{cache_dir}/temp_index-{shard_i}'
        cmd = f'python -m pyserini.encode \
                  input --corpus {self.test_file} \
                        --fields text \
                        --shard-id {shard_i} \
                        --shard-num 2 \
                  output --embeddings {index_dir} \
                         --to-faiss \
                  encoder --encoder castorini/tct_colbert-v2-hnp-msmarco \
                          --fields text \
                          --batch 1 \
                          --device cpu'
        status = os.system(cmd)
        self.assertEqual(status, 0)  # each shard must encode cleanly
        self.assertIsFile(os.path.join(index_dir, 'docid'))
        self.assertIsFile(os.path.join(index_dir, 'index'))
    # Merge shard indexes temp_index-0 and temp_index-1 into temp_index-full.
    cmd = f'python -m pyserini.index.merge_faiss_indexes --prefix {cache_dir}/temp_index- --shard-num 2'
    index_dir = f'{cache_dir}/temp_index-full'
    docid_fn = os.path.join(index_dir, 'docid')
    index_fn = os.path.join(index_dir, 'index')
    status = os.system(cmd)
    self.assertEqual(status, 0)
    self.assertIsFile(docid_fn)
    self.assertIsFile(index_fn)
    index = faiss.read_index(index_fn)
    vectors = index.reconstruct_n(0, index.ntotal)
    with open(docid_fn) as f:
        self.assertListEqual([docid.strip() for docid in f], self.docids)
    # Spot-check first/last components of document vectors 0 and 2; values
    # match the unsharded encoder test in this file.
    self.assertAlmostEqual(vectors[0][0], 0.12679848074913025, places=4)
    self.assertAlmostEqual(vectors[0][-1], -0.0037349488120526075, places=4)
    self.assertAlmostEqual(vectors[2][0], 0.03678430616855621, places=4)
    self.assertAlmostEqual(vectors[2][-1], 0.13209162652492523, places=4)
def test_tct_colbert_v2_encoder_cmd(self):
    # Encode the test corpus to JSONL embeddings (no --to-faiss) and validate
    # doc ids, contents, and a few vector components against known values.
    cache_dir = get_cache_home()
    index_dir = f'{cache_dir}/temp_index'
    cmd = f'python -m pyserini.encode \
              input --corpus {self.test_file} \
                    --fields text \
              output --embeddings {index_dir} \
              encoder --encoder castorini/tct_colbert-v2-hnp-msmarco \
                      --fields text \
                      --batch 1 \
                      --device cpu'
    status = os.system(cmd)
    self.assertEqual(status, 0)  # shell exit code 0 == encoding succeeded
    embedding_json_fn = os.path.join(index_dir, 'embeddings.jsonl')
    self.assertIsFile(embedding_json_fn)
    # One JSON object per document: {"id", "contents", "vector"}.
    with open(embedding_json_fn) as f:
        embeddings = [json.loads(line) for line in f]
    self.assertListEqual([entry["id"] for entry in embeddings], self.docids)
    self.assertListEqual(
        [entry["contents"] for entry in embeddings],
        [entry.strip() for entry in self.texts],
    )
    # Spot-check first/last components of document vectors 0 and 2.
    self.assertAlmostEqual(embeddings[0]['vector'][0], 0.12679848074913025, places=4)
    self.assertAlmostEqual(embeddings[0]['vector'][-1], -0.0037349488120526075, places=4)
    self.assertAlmostEqual(embeddings[2]['vector'][0], 0.03678430616855621, places=4)
    self.assertAlmostEqual(embeddings[2]['vector'][-1], 0.13209162652492523, places=4)
def __init__(self, ibm_model: str, index: str, field_name: str):
    """Initialize BM25 search plus IBM-model translation-probability reranking.

    Parameters
    ----------
    ibm_model : str
        Path/name of the IBM translation model directory.
    index : str
        Prebuilt index name; must be ``'msmarco-passage-ltr'`` or
        ``'msmarco-document-segment-ltr'``.
    field_name : str
        Document field the translation tables are keyed on.

    Raises
    ------
    ValueError
        If ``index`` is not one of the two supported prebuilt indexes.
    """
    self.ibm_model = ibm_model
    self.bm25search = LuceneSearcher.from_prebuilt_index(index)
    index_directory = os.path.join(get_cache_home(), 'indexes')
    if index == 'msmarco-passage-ltr':
        index_path = os.path.join(
            index_directory,
            'index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3'
        )
    elif index == 'msmarco-document-segment-ltr':
        index_path = os.path.join(
            index_directory,
            'lucene-index.msmarco-doc-segmented.ibm.13064bdaf8e8a79222634d67ecd3ddb5'
        )
    else:
        # Bug fix: the original only printed a warning here and then crashed
        # with a NameError on the undefined index_path below; fail fast with a
        # clear, catchable error instead.
        raise ValueError(
            'We currently only support two indexes: msmarco-passage-ltr and '
            f'msmarco-document-segment-ltr, but the index you inserted ({index}) '
            'is not one of those')
    self.object = JLuceneSearcher(index_path)
    self.index_reader = JIndexReader().getReader(index_path)
    self.field_name = field_name
    self.source_lookup, self.target_lookup, self.tran = self.load_tranprobs_table()
    # Fixed-size worker pool for parallel per-query scoring.
    self.pool = ThreadPool(24)
# Collection name -> JQrels attribute name. Resolved lazily via getattr so
# only the attribute for the requested collection is ever touched, exactly
# mirroring the original if/elif chain.
_QRELS_ATTR_BY_COLLECTION = {
    'trec1-adhoc': 'TREC1_ADHOC',
    'trec2-adhoc': 'TREC2_ADHOC',
    'trec3-adhoc': 'TREC3_ADHOC',
    'robust04': 'ROBUST04',
    'robust05': 'ROBUST05',
    'core17': 'CORE17',
    'core18': 'CORE18',
    'wt10g': 'WT10G',
    'trec2004-terabyte': 'TREC2004_TERABYTE',
    'trec2005-terabyte': 'TREC2005_TERABYTE',
    'trec2006-terabyte': 'TREC2006_TERABYTE',
    'trec2011-web': 'TREC2011_WEB',
    'trec2012-web': 'TREC2012_WEB',
    'trec2013-web': 'TREC2013_WEB',
    'trec2014-web': 'TREC2014_WEB',
    'mb11': 'MB11',
    'mb12': 'MB12',
    'mb13': 'MB13',
    'mb14': 'MB14',
    'car17v1.5-benchmarkY1test': 'CAR17V15_BENCHMARK_Y1_TEST',
    'car17v2.0-benchmarkY1test': 'CAR17V20_BENCHMARK_Y1_TEST',
    'dl19-doc': 'TREC2019_DL_DOC',
    'dl19-passage': 'TREC2019_DL_PASSAGE',
    'dl20-doc': 'TREC2020_DL_DOC',
    'dl20-passage': 'TREC2020_DL_PASSAGE',
    'msmarco-doc-dev': 'MSMARCO_DOC_DEV',
    'msmarco-passage-dev-subset': 'MSMARCO_PASSAGE_DEV_SUBSET',
    'ntcir8-zh': 'NTCIR8_ZH',
    'clef2006-fr': 'CLEF2006_FR',
    'trec2002-ar': 'TREC2002_AR',
    'fire2012-bn': 'FIRE2012_BN',
    'fire2012-hi': 'FIRE2012_HI',
    'fire2012-en': 'FIRE2012_EN',
    'covid-complete': 'COVID_COMPLETE',
    'covid-round1': 'COVID_ROUND1',
    'covid-round2': 'COVID_ROUND2',
    'covid-round3': 'COVID_ROUND3',
    'covid-round3-cumulative': 'COVID_ROUND3_CUMULATIVE',
    'covid-round4': 'COVID_ROUND4',
    'covid-round4-cumulative': 'COVID_ROUND4_CUMULATIVE',
    'covid-round5': 'COVID_ROUND5',
    'trec2018-bl': 'TREC2018_BL',
    'trec2019-bl': 'TREC2019_BL',
}


def get_qrels_file(collection_name):
    """
    Parameters
    ----------
    collection_name : str
        collection_name

    Returns
    -------
    path : str
        path of the qrels file

    Raises
    ------
    FileNotFoundError
        If no qrels are registered for ``collection_name``.
    """
    # Idiom fix: the 40+-branch if/elif chain is replaced by a dispatch table.
    attr = _QRELS_ATTR_BY_COLLECTION.get(collection_name)
    if attr is None:
        raise FileNotFoundError(f'no qrels file for {collection_name}')
    qrels = getattr(JQrels, attr)
    target_path = os.path.join(get_cache_home(), qrels.path)
    if os.path.exists(target_path):
        return target_path
    target_dir = os.path.split(target_path)[0]
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    # Materialize the bundled qrels resource into the cache on first use.
    with open(target_path, 'w') as file:
        qrels_content = JRelevanceJudgments.getQrelsResource(qrels)
        file.write(qrels_content)
    return target_path
# Collection name -> JQrels attribute name for this module's supported
# collections. Resolved lazily via getattr so only the attribute for the
# requested collection is ever touched, mirroring the original if/elif chain.
_QRELS_BY_NAME = {
    'robust04': 'ROBUST04',
    'robust05': 'ROBUST05',
    'core17': 'CORE17',
    'core18': 'CORE18',
    'car17v1.5-benchmarkY1test': 'CAR17V15_BENCHMARK_Y1_TEST',
    'car17v2.0-benchmarkY1test': 'CAR17V20_BENCHMARK_Y1_TEST',
    'dl19-doc': 'TREC2019_DL_DOC',
    'dl19-passage': 'TREC2019_DL_PASSAGE',
    'msmarco-doc-dev': 'MSMARCO_DOC_DEV',
    'msmarco-passage-dev-subset': 'MSMARCO_PASSAGE_DEV_SUBSET',
    'covid-round1': 'COVID_ROUND1',
    'covid-round2': 'COVID_ROUND2',
    'covid-round3': 'COVID_ROUND3',
    'covid-round3-cumulative': 'COVID_ROUND3_CUMULATIVE',
    'covid-round4': 'COVID_ROUND4',
    'covid-round4-cumulative': 'COVID_ROUND4_CUMULATIVE',
    'covid-round5': 'COVID_ROUND5',
    'covid-complete': 'COVID_COMPLETE',
    'trec2018-bl': 'TREC2018_BL',
    'trec2019-bl': 'TREC2019_BL',
}


def get_qrels_file(collection_name):
    """
    Parameters
    ----------
    collection_name : str
        collection_name

    Returns
    -------
    path : str
        path of the qrels file

    Raises
    ------
    FileNotFoundError
        If no qrels are registered for ``collection_name``.
    """
    # Idiom fix: the long if/elif chain is replaced by a dispatch table.
    attr = _QRELS_BY_NAME.get(collection_name)
    if attr is None:
        raise FileNotFoundError(f'no qrels file for {collection_name}')
    qrels = getattr(JQrels, attr)
    target_path = os.path.join(get_cache_home(), qrels.path)
    if os.path.exists(target_path):
        return target_path
    target_dir = os.path.split(target_path)[0]
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    # Materialize the bundled qrels resource into the cache on first use.
    with open(target_path, 'w') as file:
        qrels_content = JRelevanceJudgments.getQrelsResource(qrels)
        file.write(qrels_content)
    return target_path