def get_dummy_custom_hf_index_retriever(self, from_disk: bool):
    """Create a ``RagRetriever`` configured with ``index_name="custom"``.

    Args:
        from_disk: When true, the dummy dataset and its "embeddings" FAISS
            index are written under ``self.tmpdirname`` and the retriever is
            built from the paths stored in the config. When false, the
            retriever receives an in-memory ``CustomHFIndex`` wrapping the
            dummy dataset.

    Returns:
        The constructed ``RagRetriever``.
    """
    dummy_dataset = self.get_dummy_dataset()
    rag_config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
        index_name="custom",
    )

    if not from_disk:
        # In-memory variant: hand the dataset to the retriever directly.
        return RagRetriever(
            rag_config,
            question_encoder_tokenizer=self.get_dpr_tokenizer(),
            generator_tokenizer=self.get_bart_tokenizer(),
            index=CustomHFIndex(rag_config.retrieval_vector_size, dummy_dataset),
        )

    # On-disk variant: record both locations in the config first.
    passages_dir = os.path.join(self.tmpdirname, "dataset")
    faiss_file = os.path.join(self.tmpdirname, "index.faiss")
    rag_config.passages_path = passages_dir
    rag_config.index_path = faiss_file

    # The index is saved on its own and dropped before the dataset is saved.
    dummy_dataset.get_index("embeddings").save(faiss_file)
    dummy_dataset.drop_index("embeddings")
    dummy_dataset.save_to_disk(passages_dir)
    # Drop the local reference before the retriever is constructed.
    del dummy_dataset

    return RagRetriever(
        rag_config,
        question_encoder_tokenizer=self.get_dpr_tokenizer(),
        generator_tokenizer=self.get_bart_tokenizer(),
    )
def get_dummy_hf_index_retriever(self):
    """Create a ``RagRetriever`` over a two-row dummy dataset.

    The dataset holds two passages ("foo" / "bar") whose embeddings are an
    all-ones vector and an all-twos vector of length
    ``self.retrieval_vector_size``, indexed with a flat inner-product FAISS
    index. While the retriever is constructed,
    ``transformers.retrieval_rag.load_dataset`` is patched to return that
    dataset.

    Returns:
        The constructed ``RagRetriever``.
    """
    vector_size = self.retrieval_vector_size
    first_embedding = np.ones(vector_size)
    second_embedding = 2 * np.ones(vector_size)
    columns = {
        "id": ["0", "1"],
        "text": ["foo", "bar"],
        "title": ["Foo", "Bar"],
        "embeddings": [first_embedding, second_embedding],
    }
    dummy_dataset = Dataset.from_dict(columns)
    dummy_dataset.add_faiss_index(
        "embeddings",
        string_factory="Flat",
        metric_type=faiss.METRIC_INNER_PRODUCT,
    )

    rag_config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
    )

    # Every load_dataset call made during construction yields the dummy data.
    with patch("transformers.retrieval_rag.load_dataset") as load_dataset_mock:
        load_dataset_mock.return_value = dummy_dataset
        rag_retriever = RagRetriever(
            rag_config,
            question_encoder_tokenizer=self.get_dpr_tokenizer(),
            generator_tokenizer=self.get_bart_tokenizer(),
        )
    return rag_retriever
def test_legacy_hf_index_retriever_save_and_from_pretrained(self):
    """Round-trip the legacy-index retriever through save/from_pretrained.

    The retriever is saved into a temporary directory, reloaded with
    ``RagRetriever.from_pretrained``, and then queried with two float32
    vectors (all ones and all minus-ones) for a single document each.
    """
    retriever = self.get_dummy_legacy_index_retriever()
    with tempfile.TemporaryDirectory() as tmp_dirname:
        retriever.save_pretrained(tmp_dirname)
        retriever = RagRetriever.from_pretrained(tmp_dirname)
        self.assertIsInstance(retriever, RagRetriever)
        hidden_states = np.array(
            [
                np.ones(self.retrieval_vector_size),
                -np.ones(self.retrieval_vector_size),
            ],
            dtype=np.float32,
        )
        out = retriever.retrieve(hidden_states, n_docs=1)
        # assertIsNotNone reports the offending value on failure, unlike
        # assertTrue(out is not None), which only says "False is not true".
        self.assertIsNotNone(out)
def test_canonical_hf_index_retriever_save_and_from_pretrained(self):
    """Round-trip the canonical-index retriever through save/from_pretrained.

    Both ``save_pretrained`` and ``from_pretrained`` run while
    ``transformers.retrieval_rag.load_dataset`` is patched to return the
    dummy dataset. The reloaded retriever is then queried, outside the
    patch, with two float32 vectors (all ones and all minus-ones) for a
    single document each.
    """
    retriever = self.get_dummy_canonical_hf_index_retriever()
    with tempfile.TemporaryDirectory() as tmp_dirname:
        with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
            mock_load_dataset.return_value = self.get_dummy_dataset()
            retriever.save_pretrained(tmp_dirname)
            retriever = RagRetriever.from_pretrained(tmp_dirname)
        self.assertIsInstance(retriever, RagRetriever)
        hidden_states = np.array(
            [
                np.ones(self.retrieval_vector_size),
                -np.ones(self.retrieval_vector_size),
            ],
            dtype=np.float32,
        )
        out = retriever.retrieve(hidden_states, n_docs=1)
        # assertIsNotNone reports the offending value on failure, unlike
        # assertTrue(out is not None), which only says "False is not true".
        self.assertIsNotNone(out)
def get_dummy_canonical_hf_index_retriever(self):
    """Create a ``RagRetriever`` with a default-index config.

    No ``index_name`` is passed to ``RagConfig``. While the retriever is
    constructed, ``transformers.retrieval_rag.load_dataset`` is patched to
    return the result of ``self.get_dummy_dataset()``.

    Returns:
        The constructed ``RagRetriever``.
    """
    dummy_dataset = self.get_dummy_dataset()
    rag_config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
    )

    # Every load_dataset call made during construction yields the dummy data.
    with patch("transformers.retrieval_rag.load_dataset") as load_dataset_mock:
        load_dataset_mock.return_value = dummy_dataset
        rag_retriever = RagRetriever(
            rag_config,
            question_encoder_tokenizer=self.get_dpr_tokenizer(),
            generator_tokenizer=self.get_bart_tokenizer(),
        )
    return rag_retriever
def get_dummy_legacy_index_retriever(self):
    """Create a ``RagRetriever`` configured with ``index_name="legacy"``.

    Writes three files under ``self.tmpdirname`` and points both
    ``index_path`` and ``passages_path`` of the config at that directory:

    * ``hf_bert_base.hnswSQ8_correct_phi_128.c_index.index.dpr`` -- the FAISS
      index saved from the dummy dataset's "embeddings" column,
    * ``hf_bert_base.hnswSQ8_correct_phi_128.c_index.index_meta.dpr`` -- the
      pickled list of passage ids,
    * ``psgs_w100.tsv.pkl`` -- a pickled ``{id: [text, title]}`` mapping.

    Returns:
        The constructed ``RagRetriever``.
    """
    # NOTE(review): the embeddings have retrieval_vector_size + 1 entries;
    # presumably the legacy index format carries one extra dimension --
    # confirm against the legacy index implementation.
    dataset = Dataset.from_dict(
        {
            "id": ["0", "1"],
            "text": ["foo", "bar"],
            "title": ["Foo", "Bar"],
            "embeddings": [
                np.ones(self.retrieval_vector_size + 1),
                2 * np.ones(self.retrieval_vector_size + 1),
            ],
        }
    )
    dataset.add_faiss_index(
        "embeddings",
        string_factory="Flat",
        metric_type=faiss.METRIC_INNER_PRODUCT,
    )

    index_file_name = os.path.join(
        self.tmpdirname, "hf_bert_base.hnswSQ8_correct_phi_128.c_index"
    )
    dataset.save_faiss_index("embeddings", index_file_name + ".index.dpr")
    # Use context managers so the files are flushed and closed before the
    # retriever is built; the bare open() calls leaked both handles.
    with open(index_file_name + ".index_meta.dpr", "wb") as meta_file:
        pickle.dump(dataset["id"], meta_file)

    passages_file_name = os.path.join(self.tmpdirname, "psgs_w100.tsv.pkl")
    passages = {
        sample["id"]: [sample["text"], sample["title"]] for sample in dataset
    }
    with open(passages_file_name, "wb") as passages_file:
        pickle.dump(passages, passages_file)

    config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
        index_name="legacy",
        index_path=self.tmpdirname,
        passages_path=self.tmpdirname,
    )
    retriever = RagRetriever(
        config,
        question_encoder_tokenizer=self.get_dpr_tokenizer(),
        generator_tokenizer=self.get_bart_tokenizer(),
    )
    return retriever