def get_dummy_custom_hf_index_retriever(self, from_disk: bool):
    """Build a ``RagRetriever`` configured with ``index_name="custom"`` over the dummy dataset.

    Args:
        from_disk: if True, the dataset's "embeddings" FAISS index and the dataset itself are
            written under ``self.tmpdirname`` and the retriever is built from a config whose
            ``passages_path`` / ``index_path`` point at them. If False, the in-memory dataset is
            wrapped in a ``CustomHFIndex`` and passed to the retriever directly.

    Returns:
        The constructed ``RagRetriever``.
    """
    dataset = self.get_dummy_dataset()
    config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
        index_name="custom",
    )

    if not from_disk:
        # In-memory path: hand the dataset to the retriever through an explicit index object.
        return RagRetriever(
            config,
            question_encoder_tokenizer=self.get_dpr_tokenizer(),
            generator_tokenizer=self.get_bart_tokenizer(),
            index=CustomHFIndex(config.retrieval_vector_size, dataset),
        )

    # On-disk path: persist the index and the dataset separately and point the config at them.
    passages_dir = os.path.join(self.tmpdirname, "dataset")
    faiss_file = os.path.join(self.tmpdirname, "index.faiss")
    config.passages_path = passages_dir
    config.index_path = faiss_file
    dataset.get_index("embeddings").save(faiss_file)
    # The index was saved on its own above; detach it before saving the dataset
    # (presumably save_to_disk refuses a dataset with an attached index — confirm).
    dataset.drop_index("embeddings")
    dataset.save_to_disk(passages_dir)
    # Drop the local reference; the retriever below receives only the config paths.
    del dataset
    return RagRetriever(
        config,
        question_encoder_tokenizer=self.get_dpr_tokenizer(),
        generator_tokenizer=self.get_bart_tokenizer(),
    )
def distributed_retriever_check(self, retriever: RagRetriever, hidden_states: np.ndarray, n_docs: int) -> None:
    """Assert that ``retriever`` returns the expected top documents for the two dummy queries.

    Args:
        retriever: the retriever under test; the expected ids below assume it was built over the
            two-document dummy dataset (ids "0" and "1") — TODO confirm for each caller.
        hidden_states: the query vectors. The assertions assume exactly two queries, the first
            of which scores highest against doc "1" and the second against doc "0".
        n_docs: number of documents retrieved per query; every check below holds for any value.
    """
    retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
    self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
    self.assertEqual(len(doc_dicts), 2)
    self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
    self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
    self.assertEqual(doc_dicts[0]["id"][0], "1")  # max inner product is reached with second doc
    self.assertEqual(doc_dicts[1]["id"][0], "0")  # max inner product is reached with first doc
    # Check the id matrix shape for the requested n_docs, then only the top-ranked id per query.
    # Comparing the whole matrix to [[1], [0]] would only be valid when n_docs == 1.
    self.assertEqual(tuple(doc_ids.shape), (2, n_docs))
    self.assertListEqual(doc_ids[:, 0].tolist(), [1, 0])
def test_legacy_hf_index_retriever_save_and_from_pretrained(self):
    """Round-trip a legacy-index retriever through ``save_pretrained``/``from_pretrained``.

    Checks that the reloaded object is a ``RagRetriever`` and that it can still answer a
    ``retrieve`` call for two queries.
    """
    retriever = self.get_dummy_legacy_index_retriever()
    with tempfile.TemporaryDirectory() as tmp_dirname:
        retriever.save_pretrained(tmp_dirname)
        retriever = RagRetriever.from_pretrained(tmp_dirname)
        self.assertIsInstance(retriever, RagRetriever)
        # Two queries: an all-ones vector and an all-minus-ones vector.
        hidden_states = np.array(
            [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)],
            dtype=np.float32,
        )
        out = retriever.retrieve(hidden_states, n_docs=1)
        # assertIsNotNone reports the offending value on failure, unlike assertTrue(x is not None).
        self.assertIsNotNone(out)
def test_canonical_hf_index_retriever_save_and_from_pretrained(self):
    """Round-trip a canonical-index retriever through ``save_pretrained``/``from_pretrained``.

    ``transformers.models.rag.retrieval_rag.load_dataset`` is patched to return the dummy dataset
    for the whole round trip, so that any dataset load made while saving, reloading or retrieving
    receives in-memory data (presumably avoiding a download — confirm).
    """
    retriever = self.get_dummy_canonical_hf_index_retriever()
    with tempfile.TemporaryDirectory() as tmp_dirname:
        with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
            mock_load_dataset.return_value = self.get_dummy_dataset()
            retriever.save_pretrained(tmp_dirname)
            retriever = RagRetriever.from_pretrained(tmp_dirname)
            self.assertIsInstance(retriever, RagRetriever)
            # Two queries: an all-ones vector and an all-minus-ones vector.
            hidden_states = np.array(
                [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)],
                dtype=np.float32,
            )
            out = retriever.retrieve(hidden_states, n_docs=1)
            # assertIsNotNone reports the offending value on failure, unlike assertTrue(x is not None).
            self.assertIsNotNone(out)
def get_dummy_canonical_hf_index_retriever(self):
    """Build a ``RagRetriever`` whose dataset loads are served by the dummy dataset.

    The config sets no ``index_name`` (the default, presumably the canonical HF index — confirm).
    While the retriever is constructed, ``retrieval_rag.load_dataset`` is patched to return the
    in-memory dummy dataset.

    Returns:
        The constructed ``RagRetriever``.
    """
    dataset = self.get_dummy_dataset()
    config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
    )
    loader_target = "transformers.models.rag.retrieval_rag.load_dataset"
    with patch(loader_target) as fake_load_dataset:
        fake_load_dataset.return_value = dataset
        # The patch is undone when the with-block exits, i.e. before the caller gets the retriever.
        return RagRetriever(
            config,
            question_encoder_tokenizer=self.get_dpr_tokenizer(),
            generator_tokenizer=self.get_bart_tokenizer(),
        )
def get_dummy_legacy_index_retriever(self):
    """Build a ``RagRetriever`` with ``index_name="legacy"`` over files written to ``self.tmpdirname``.

    Writes three files into ``self.tmpdirname``:
      * ``hf_bert_base.hnswSQ8_correct_phi_128.c_index.index.dpr`` — the saved FAISS index,
      * ``hf_bert_base.hnswSQ8_correct_phi_128.c_index.index_meta.dpr`` — pickled list of ids,
      * ``psgs_w100.tsv.pkl`` — pickled ``{id: [text, title]}`` passages dict.
    These are presumably the file names the legacy index looks for under ``index_path`` — confirm
    against ``LegacyIndex``.

    Returns:
        The constructed ``RagRetriever``.
    """
    dataset = Dataset.from_dict(
        {
            "id": ["0", "1"],
            "text": ["foo", "bar"],
            "title": ["Foo", "Bar"],
            # One dimension wider than retrieval_vector_size. This is presumably because the legacy
            # index appends an auxiliary dimension to queries — confirm against LegacyIndex.
            "embeddings": [
                np.ones(self.retrieval_vector_size + 1),
                2 * np.ones(self.retrieval_vector_size + 1),
            ],
        }
    )
    dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
    index_file_name = os.path.join(self.tmpdirname, "hf_bert_base.hnswSQ8_correct_phi_128.c_index")
    dataset.save_faiss_index("embeddings", index_file_name + ".index.dpr")
    # Context managers guarantee the pickles are flushed and closed before the retriever reads
    # them, instead of relying on garbage collection of an anonymous file object.
    with open(index_file_name + ".index_meta.dpr", "wb") as meta_file:
        pickle.dump(dataset["id"], meta_file)
    passages_file_name = os.path.join(self.tmpdirname, "psgs_w100.tsv.pkl")
    passages = {sample["id"]: [sample["text"], sample["title"]] for sample in dataset}
    with open(passages_file_name, "wb") as passages_file:
        pickle.dump(passages, passages_file)
    config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
        index_name="legacy",
        index_path=self.tmpdirname,
    )
    retriever = RagRetriever(
        config,
        question_encoder_tokenizer=self.get_dpr_tokenizer(),
        generator_tokenizer=self.get_bart_tokenizer(),
    )
    return retriever