Example #1
0
    def test_faiss_pq(self):
        """Build a PQ FAISS index via the CLI and verify its docids and vectors.

        Encodes the test collection with ``prepare_encoded_collection``, runs
        ``pyserini.index.faiss`` with ``--pq``, then checks that the ``docid``
        and ``index`` files exist, that the docids match ``self.docids`` and
        that selected reconstructed vector components match reference values.
        """
        index_dir = f'{get_cache_home()}/temp_pq'
        encoded_corpus_dir = self.prepare_encoded_collection()
        cmd = f'python -m pyserini.index.faiss \
            --input {encoded_corpus_dir} \
            --output {index_dir} \
            --pq-m 3 \
            --efC 1 \
            --pq-nbits 128 \
            --pq'

        # The indexing command must exit with status 0.
        self.assertEqual(os.system(cmd), 0)

        docid_path = os.path.join(index_dir, 'docid')
        index_path = os.path.join(index_dir, 'index')
        for produced_file in (docid_path, index_path):
            self.assertIsFile(produced_file)

        faiss_index = faiss.read_index(index_path)
        vectors = faiss_index.reconstruct_n(0, faiss_index.ntotal)

        with open(docid_path) as docid_file:
            stored_docids = [line.strip() for line in docid_file]
        self.assertListEqual(stored_docids, self.docids)

        # (row, column, expected value) triples for the reconstructed vectors.
        reference_components = (
            (0, 0, 0.04343192),
            (0, -1, 0.075478144),
            (2, 0, 0.04343192),
            (2, -1, 0.075478144),
        )
        for row, col, expected in reference_components:
            self.assertAlmostEqual(vectors[row][col], expected, places=4)
Example #2
0
    def test_faiss_hnsw(self):
        """Build an HNSW FAISS index via the CLI and verify its docids and vectors.

        Encodes the test collection with ``prepare_encoded_collection``, runs
        ``pyserini.index.faiss`` with ``--hnsw``, then checks that the ``docid``
        and ``index`` files exist, that the docids match ``self.docids`` and
        that selected reconstructed vector components match reference values.
        """
        index_dir = f'{get_cache_home()}/temp_hnsw'
        encoded_corpus_dir = self.prepare_encoded_collection()
        cmd = f'python -m pyserini.index.faiss \
            --input {encoded_corpus_dir} \
            --output {index_dir} \
            --M 3 \
            --hnsw'

        # The indexing command must exit with status 0.
        self.assertEqual(os.system(cmd), 0)

        docid_path = os.path.join(index_dir, 'docid')
        index_path = os.path.join(index_dir, 'index')
        for produced_file in (docid_path, index_path):
            self.assertIsFile(produced_file)

        faiss_index = faiss.read_index(index_path)
        vectors = faiss_index.reconstruct_n(0, faiss_index.ntotal)

        with open(docid_path) as docid_file:
            stored_docids = [line.strip() for line in docid_file]
        self.assertListEqual(stored_docids, self.docids)

        # (row, column, expected value) triples for the reconstructed vectors.
        reference_components = (
            (0, 0, 0.12679848074913025),
            (0, -1, -0.0037349488120526075),
            (2, 0, 0.03678430616855621),
            (2, -1, 0.13209162652492523),
        )
        for row, col, expected in reference_components:
            self.assertAlmostEqual(vectors[row][col], expected, places=4)
Example #3
0
 def __init__(self, model: str, ibm_model:str, index:str, data: str, prebuilt: bool):
     """Set up the index reader and feature extractor for the given index.

     Parameters
     ----------
     model : str
         Stored unchanged on ``self.model``.
     ibm_model : str
         Stored unchanged on ``self.ibm_model``.
     index : str
         Prebuilt index name when ``prebuilt`` is true, otherwise the index
         location handed to ``IndexReader`` and ``FeatureExtractor``.
     data : str
         ``'passage'`` selects the passage LTR index directory; any other
         value selects the doc-per-passage LTR index directory. Only
         consulted when ``prebuilt`` is true.
     prebuilt : bool
         Whether ``index`` names a prebuilt index.

     NOTE(review): ``self.lucene_searcher`` is only assigned when
     ``prebuilt`` is true -- confirm callers never need it otherwise.
     """
     self.model = model
     self.ibm_model = ibm_model
     if not prebuilt:
         index_path = index
         self.index_reader = IndexReader(index)
     else:
         self.lucene_searcher = LuceneSearcher.from_prebuilt_index(index)
         # The feature extractor is pointed at a fixed directory under the
         # cache home rather than at whatever `index` resolved to.
         if data == 'passage':
             ltr_index_name = 'index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3'
         else:
             ltr_index_name = 'index-msmarco-doc-per-passage-ltr-20211031-33e4151.bd60e89041b4ebbabc4bf0cfac608a87'
         index_path = os.path.join(get_cache_home(), 'indexes', ltr_index_name)
         self.index_reader = IndexReader.from_prebuilt_index(index)
     # Half of the reported CPUs, but never fewer than one.
     worker_count = max(multiprocessing.cpu_count()//2, 1)
     self.fe = FeatureExtractor(index_path, worker_count)
     self.data = data
Example #4
0
 def download_kilt_topics(cls, task: str, force=False):
     """Download the topics file for a KILT task into the local cache.

     Each URL registered for the task is tried in order; the result of the
     first ``download_url`` call that does not fail with a network error is
     returned.

     Parameters
     ----------
     task : str
         Key into ``KILT_QUERY_INFO``.
     force : bool
         Forwarded to ``download_url`` as ``force``.

     Returns
     -------
     Whatever ``download_url`` returns for the first URL that succeeds.

     Raises
     ------
     ValueError
         If ``task`` is not a known key, or if every URL failed with
         ``HTTPError``/``URLError`` (the last such error is attached as the
         cause).
     """
     if task not in KILT_QUERY_INFO:
         raise ValueError(f'Unrecognized query name {task}')
     task_info = KILT_QUERY_INFO[task]
     md5 = task_info['md5']
     save_dir = os.path.join(get_cache_home(), 'queries')
     # exist_ok avoids the check-then-create race between concurrent callers.
     os.makedirs(save_dir, exist_ok=True)
     last_error = None
     for url in task_info['urls']:
         try:
             return download_url(url, save_dir, force=force, md5=md5)
         except (HTTPError, URLError) as e:
             # Remember the failure so the final error can report its cause.
             last_error = e
             print(
                 f'Unable to download encoded query at {url}, trying next URL...'
             )
     raise ValueError(
         'Unable to download encoded query at any known URLs.') from last_error
Example #5
0
    def prepare_encoded_collection(self):
        """Encode the test corpus into a FAISS directory and return its path.

        Runs ``pyserini.encode`` on ``self.test_file`` with ``--to-faiss``,
        asserts that the command exits with status 0 and that the ``docid``
        and ``index`` files were produced.

        Returns
        -------
        str
            The ``temp_index`` directory under the cache home.
        """
        encoded_corpus_dir = f'{get_cache_home()}/temp_index'
        cmd = f'python -m pyserini.encode \
                input   --corpus {self.test_file} \
                        --fields text \
                output  --embeddings {encoded_corpus_dir} \
                        --to-faiss \
                encoder --encoder castorini/tct_colbert-v2-hnp-msmarco \
                        --fields text \
                        --batch 1 \
                        --device cpu'

        exit_status = os.system(cmd)
        self.assertEqual(exit_status, 0)
        for produced_name in ('docid', 'index'):
            self.assertIsFile(os.path.join(encoded_corpus_dir, produced_name))
        return encoded_corpus_dir
Example #6
0
    def test_tct_colbert_v2_encoder_cmd_shard(self):
        """Encode the corpus in two shards, merge them, and verify the result.

        Each shard is encoded to its own ``temp_index-<shard>`` FAISS
        directory; ``pyserini.index.merge_faiss_indexes`` is then run and the
        ``temp_index-full`` directory is checked for the expected docids and
        reference vector components.
        """
        cache_dir = get_cache_home()

        for shard_i in range(2):
            index_dir = f'{cache_dir}/temp_index-{shard_i}'
            cmd = f'python -m pyserini.encode \
                    input   --corpus {self.test_file} \
                            --fields text \
                            --shard-id {shard_i} \
                            --shard-num 2 \
                    output  --embeddings {index_dir} \
                            --to-faiss \
                    encoder --encoder castorini/tct_colbert-v2-hnp-msmarco \
                            --fields text \
                            --batch 1 \
                            --device cpu'

            self.assertEqual(os.system(cmd), 0)
            for produced_name in ('docid', 'index'):
                self.assertIsFile(os.path.join(index_dir, produced_name))

        merged_dir = f'{cache_dir}/temp_index-full'
        docid_path = os.path.join(merged_dir, 'docid')
        index_path = os.path.join(merged_dir, 'index')
        merge_cmd = f'python -m pyserini.index.merge_faiss_indexes --prefix {cache_dir}/temp_index- --shard-num 2'

        self.assertEqual(os.system(merge_cmd), 0)
        for produced_file in (docid_path, index_path):
            self.assertIsFile(produced_file)

        merged_index = faiss.read_index(index_path)
        vectors = merged_index.reconstruct_n(0, merged_index.ntotal)

        with open(docid_path) as docid_file:
            stored_docids = [line.strip() for line in docid_file]
        self.assertListEqual(stored_docids, self.docids)

        # (row, column, expected value) triples for the reconstructed vectors.
        reference_components = (
            (0, 0, 0.12679848074913025),
            (0, -1, -0.0037349488120526075),
            (2, 0, 0.03678430616855621),
            (2, -1, 0.13209162652492523),
        )
        for row, col, expected in reference_components:
            self.assertAlmostEqual(vectors[row][col], expected, places=4)
Example #7
0
    def test_tct_colbert_v2_encoder_cmd(self):
        """Encode the corpus to JSONL via the CLI and verify ids, text and vectors.

        Runs ``pyserini.encode`` without ``--to-faiss``, reads the resulting
        ``embeddings.jsonl`` and checks that ids match ``self.docids``, that
        contents match the stripped ``self.texts`` and that selected vector
        components match reference values.
        """
        index_dir = f'{get_cache_home()}/temp_index'
        cmd = f'python -m pyserini.encode \
                  input   --corpus {self.test_file} \
                          --fields text \
                  output  --embeddings {index_dir} \
                  encoder --encoder castorini/tct_colbert-v2-hnp-msmarco \
                          --fields text \
                          --batch 1 \
                          --device cpu'

        # The encoding command must exit with status 0.
        self.assertEqual(os.system(cmd), 0)

        jsonl_path = os.path.join(index_dir, 'embeddings.jsonl')
        self.assertIsFile(jsonl_path)

        # One JSON record per line.
        records = []
        with open(jsonl_path) as jsonl_file:
            for raw_line in jsonl_file:
                records.append(json.loads(raw_line))

        record_ids = [record["id"] for record in records]
        self.assertListEqual(record_ids, self.docids)

        record_contents = [record["contents"] for record in records]
        expected_contents = [text.strip() for text in self.texts]
        self.assertListEqual(record_contents, expected_contents)

        # (record index, vector position, expected value) triples.
        reference_components = (
            (0, 0, 0.12679848074913025),
            (0, -1, -0.0037349488120526075),
            (2, 0, 0.03678430616855621),
            (2, -1, 0.13209162652492523),
        )
        for row, col, expected in reference_components:
            self.assertAlmostEqual(records[row]['vector'][col],
                                   expected,
                                   places=4)
Example #8
0
 def __init__(self, ibm_model: str, index: str, field_name: str):
     """Open the searcher, index reader and translation tables for an index.

     Parameters
     ----------
     ibm_model : str
         Stored unchanged on ``self.ibm_model``.
     index : str
         Prebuilt index name; must be ``'msmarco-passage-ltr'`` or
         ``'msmarco-document-segment-ltr'``.
     field_name : str
         Stored unchanged on ``self.field_name``.

     Raises
     ------
     ValueError
         If ``index`` is not one of the two supported names. Previously this
         case only printed a message and then crashed with an unbound
         ``index_path``.
     """
     # Directory name, under <cache home>/indexes, for each supported index.
     supported_indexes = {
         'msmarco-passage-ltr':
             'index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3',
         'msmarco-document-segment-ltr':
             'lucene-index.msmarco-doc-segmented.ibm.13064bdaf8e8a79222634d67ecd3ddb5',
     }
     # Validate before touching any index so an unsupported name fails fast
     # with a clear error instead of after the prebuilt index was requested.
     if index not in supported_indexes:
         raise ValueError(
             f'Unsupported index {index!r}: we currently only support two '
             'indexes: msmarco-passage-ltr and msmarco-document-segment-ltr')

     self.ibm_model = ibm_model
     self.bm25search = LuceneSearcher.from_prebuilt_index(index)
     index_directory = os.path.join(get_cache_home(), 'indexes')
     index_path = os.path.join(index_directory, supported_indexes[index])
     self.object = JLuceneSearcher(index_path)
     self.index_reader = JIndexReader().getReader(index_path)
     self.field_name = field_name
     self.source_lookup, self.target_lookup, self.tran = self.load_tranprobs_table(
     )
     self.pool = ThreadPool(24)
Example #9
0
def get_qrels_file(collection_name):
    """Return the local path of the qrels file for a collection.

    The file is looked up under the cache home; when it is not there yet its
    content is obtained from ``JRelevanceJudgments.getQrelsResource`` and
    written to that location first.

    Parameters
    ----------
    collection_name : str
        collection_name

    Returns
    -------
    path : str
        path of the qrels file

    Raises
    ------
    FileNotFoundError
        If no qrels are registered for ``collection_name``.
    """
    # Collection name -> attribute name on JQrels. Names (not the attributes
    # themselves) are stored so only the requested constant is ever resolved.
    qrels_constants = {
        'trec1-adhoc': 'TREC1_ADHOC',
        'trec2-adhoc': 'TREC2_ADHOC',
        'trec3-adhoc': 'TREC3_ADHOC',
        'robust04': 'ROBUST04',
        'robust05': 'ROBUST05',
        'core17': 'CORE17',
        'core18': 'CORE18',
        'wt10g': 'WT10G',
        'trec2004-terabyte': 'TREC2004_TERABYTE',
        'trec2005-terabyte': 'TREC2005_TERABYTE',
        'trec2006-terabyte': 'TREC2006_TERABYTE',
        'trec2011-web': 'TREC2011_WEB',
        'trec2012-web': 'TREC2012_WEB',
        'trec2013-web': 'TREC2013_WEB',
        'trec2014-web': 'TREC2014_WEB',
        'mb11': 'MB11',
        'mb12': 'MB12',
        'mb13': 'MB13',
        'mb14': 'MB14',
        'car17v1.5-benchmarkY1test': 'CAR17V15_BENCHMARK_Y1_TEST',
        'car17v2.0-benchmarkY1test': 'CAR17V20_BENCHMARK_Y1_TEST',
        'dl19-doc': 'TREC2019_DL_DOC',
        'dl19-passage': 'TREC2019_DL_PASSAGE',
        'dl20-doc': 'TREC2020_DL_DOC',
        'dl20-passage': 'TREC2020_DL_PASSAGE',
        'msmarco-doc-dev': 'MSMARCO_DOC_DEV',
        'msmarco-passage-dev-subset': 'MSMARCO_PASSAGE_DEV_SUBSET',
        'ntcir8-zh': 'NTCIR8_ZH',
        'clef2006-fr': 'CLEF2006_FR',
        'trec2002-ar': 'TREC2002_AR',
        'fire2012-bn': 'FIRE2012_BN',
        'fire2012-hi': 'FIRE2012_HI',
        'fire2012-en': 'FIRE2012_EN',
        'covid-complete': 'COVID_COMPLETE',
        'covid-round1': 'COVID_ROUND1',
        'covid-round2': 'COVID_ROUND2',
        'covid-round3': 'COVID_ROUND3',
        'covid-round3-cumulative': 'COVID_ROUND3_CUMULATIVE',
        'covid-round4': 'COVID_ROUND4',
        'covid-round4-cumulative': 'COVID_ROUND4_CUMULATIVE',
        'covid-round5': 'COVID_ROUND5',
        'trec2018-bl': 'TREC2018_BL',
        'trec2019-bl': 'TREC2019_BL',
    }
    # Non-string (possibly unhashable) input is treated as "unknown" rather
    # than failing inside the dict lookup.
    constant_name = None
    if isinstance(collection_name, str):
        constant_name = qrels_constants.get(collection_name)
    qrels = getattr(JQrels, constant_name) if constant_name else None
    if not qrels:
        raise FileNotFoundError(f'no qrels file for {collection_name}')

    target_path = os.path.join(get_cache_home(), qrels.path)
    if os.path.exists(target_path):
        return target_path

    # Fetch the content BEFORE creating the file: if the fetch fails, no empty
    # file is left behind to be mistaken for a valid cached copy next time.
    qrels_content = JRelevanceJudgments.getQrelsResource(qrels)
    target_dir = os.path.split(target_path)[0]
    os.makedirs(target_dir, exist_ok=True)
    with open(target_path, 'w') as file:
        file.write(qrels_content)
    return target_path
Example #10
0
def get_qrels_file(collection_name):
    """Return the local path of the qrels file for a collection.

    The file is looked up under the cache home; when it is not there yet its
    content is obtained from ``JRelevanceJudgments.getQrelsResource`` and
    written to that location first.

    Parameters
    ----------
    collection_name : str
        collection_name

    Returns
    -------
    path : str
        path of the qrels file

    Raises
    ------
    FileNotFoundError
        If no qrels are registered for ``collection_name``.
    """
    # Collection name -> attribute name on JQrels. Names (not the attributes
    # themselves) are stored so only the requested constant is ever resolved.
    qrels_constants = {
        'robust04': 'ROBUST04',
        'robust05': 'ROBUST05',
        'core17': 'CORE17',
        'core18': 'CORE18',
        'car17v1.5-benchmarkY1test': 'CAR17V15_BENCHMARK_Y1_TEST',
        'car17v2.0-benchmarkY1test': 'CAR17V20_BENCHMARK_Y1_TEST',
        'dl19-doc': 'TREC2019_DL_DOC',
        'dl19-passage': 'TREC2019_DL_PASSAGE',
        'msmarco-doc-dev': 'MSMARCO_DOC_DEV',
        'msmarco-passage-dev-subset': 'MSMARCO_PASSAGE_DEV_SUBSET',
        'covid-round1': 'COVID_ROUND1',
        'covid-round2': 'COVID_ROUND2',
        'covid-round3': 'COVID_ROUND3',
        'covid-round3-cumulative': 'COVID_ROUND3_CUMULATIVE',
        'covid-round4': 'COVID_ROUND4',
        'covid-round4-cumulative': 'COVID_ROUND4_CUMULATIVE',
        'covid-round5': 'COVID_ROUND5',
        'covid-complete': 'COVID_COMPLETE',
        'trec2018-bl': 'TREC2018_BL',
        'trec2019-bl': 'TREC2019_BL',
    }
    # Non-string (possibly unhashable) input is treated as "unknown" rather
    # than failing inside the dict lookup.
    constant_name = None
    if isinstance(collection_name, str):
        constant_name = qrels_constants.get(collection_name)
    qrels = getattr(JQrels, constant_name) if constant_name else None
    if not qrels:
        raise FileNotFoundError(f'no qrels file for {collection_name}')

    target_path = os.path.join(get_cache_home(), qrels.path)
    if os.path.exists(target_path):
        return target_path

    # Fetch the content BEFORE creating the file: if the fetch fails, no empty
    # file is left behind to be mistaken for a valid cached copy next time.
    qrels_content = JRelevanceJudgments.getQrelsResource(qrels)
    target_dir = os.path.split(target_path)[0]
    os.makedirs(target_dir, exist_ok=True)
    with open(target_path, 'w') as file:
        file.write(qrels_content)
    return target_path