def get_dummy_pytorch_distributed_retriever(
    self, init_retrieval, port=12345
) -> RagPyTorchDistributedRetriever:
    """Create a distributed retriever backed by a two-passage in-memory dataset.

    The dataset is built inline with constant embeddings (all ones and all
    twos) of length ``self.retrieval_vector_size`` and given a flat
    inner-product FAISS index. ``load_dataset`` is patched while the
    retriever is constructed so that it receives this dataset.

    Args:
        init_retrieval: When truthy, ``init_retrieval(port)`` is called on the
            retriever before it is returned.
        port: Value forwarded to ``init_retrieval``.

    Returns:
        The constructed ``RagPyTorchDistributedRetriever``.
    """
    dim = self.retrieval_vector_size
    columns = {
        "id": ["0", "1"],
        "text": ["foo", "bar"],
        "title": ["Foo", "Bar"],
        "embeddings": [np.ones(dim), 2 * np.ones(dim)],
    }
    passages = Dataset.from_dict(columns)
    passages.add_faiss_index(
        "embeddings",
        string_factory="Flat",
        metric_type=faiss.METRIC_INNER_PRODUCT,
    )

    rag_config = RagConfig(
        retrieval_vector_size=dim,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
    )

    # Construct the retriever while load_dataset is replaced by a mock that
    # hands back the dummy passages.
    with patch("transformers.retrieval_rag.load_dataset") as fake_load_dataset:
        fake_load_dataset.return_value = passages
        distributed_retriever = RagPyTorchDistributedRetriever(
            rag_config,
            question_encoder_tokenizer=self.get_dpr_tokenizer(),
            generator_tokenizer=self.get_bart_tokenizer(),
        )

    if not init_retrieval:
        return distributed_retriever
    distributed_retriever.init_retrieval(port)
    return distributed_retriever
def get_dummy_custom_hf_index_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
    """Create a distributed retriever configured with ``index_name="custom"``.

    Args:
        init_retrieval: When true, ``init_retrieval(port)`` is called on the
            retriever before it is returned.
        from_disk: When true, the dummy dataset and its "embeddings" index are
            saved under ``self.tmpdirname`` and the config is pointed at those
            paths; otherwise the in-memory dataset is wrapped in a
            ``CustomHFIndex`` and passed to the retriever directly.
        port: Value forwarded to ``init_retrieval``.

    Returns:
        The constructed ``RagPyTorchDistributedRetriever``.
    """
    dataset = self.get_dummy_dataset()
    config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
        index_name="custom",
    )

    if from_disk:
        passages_dir = os.path.join(self.tmpdirname, "dataset")
        index_file = os.path.join(self.tmpdirname, "index.faiss")
        config.passages_path = passages_dir
        config.index_path = index_file
        # The index is written to its own file and dropped from the dataset
        # before the dataset itself is saved.
        dataset.get_index("embeddings").save(index_file)
        dataset.drop_index("embeddings")
        dataset.save_to_disk(passages_dir)
        del dataset

    retriever_kwargs = {
        "question_encoder_tokenizer": self.get_dpr_tokenizer(),
        "generator_tokenizer": self.get_bart_tokenizer(),
    }
    if not from_disk:
        # In-memory path: hand the retriever a ready-made index object.
        retriever_kwargs["index"] = CustomHFIndex(config.retrieval_vector_size, dataset)

    retriever = RagPyTorchDistributedRetriever(config, **retriever_kwargs)
    if init_retrieval:
        retriever.init_retrieval(port)
    return retriever
def test_save_load_pretrained_with_saved_config(self):
    """Check that a RagTokenizer reloads correctly from a directory that also holds a saved config.

    Both the config and the tokenizer are saved into the same directory; the
    tokenizer is then reloaded with the config passed explicitly, and the two
    sub-tokenizers are compared against the originals.
    """
    output_dir = os.path.join(self.tmpdirname, "rag_tokenizer")
    config = RagConfig(
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
    )
    original = RagTokenizer(
        question_encoder=self.get_dpr_tokenizer(),
        generator=self.get_bart_tokenizer(),
    )

    config.save_pretrained(output_dir)
    original.save_pretrained(output_dir)
    reloaded = RagTokenizer.from_pretrained(output_dir, config=config)

    # Question-encoder side: type and vocabulary must match.
    reloaded_question_encoder = reloaded.question_encoder
    self.assertIsInstance(reloaded_question_encoder, DPRQuestionEncoderTokenizer)
    self.assertEqual(reloaded_question_encoder.vocab, original.question_encoder.vocab)

    # Generator side: type and encoder mapping must match.
    reloaded_generator = reloaded.generator
    self.assertIsInstance(reloaded_generator, BartTokenizer)
    self.assertEqual(reloaded_generator.encoder, original.generator.encoder)
def get_dummy_canonical_hf_index_retriever(self):
    """Create a ``RagRetriever`` with default index settings over the dummy dataset.

    ``load_dataset`` is patched to return ``self.get_dummy_dataset()`` while
    the retriever is being constructed.

    Returns:
        The constructed ``RagRetriever``.
    """
    dummy_passages = self.get_dummy_dataset()
    rag_config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
    )
    # The mock's return value is configured directly through patch's keyword arguments.
    with patch("transformers.retrieval_rag.load_dataset", return_value=dummy_passages):
        return RagRetriever(
            rag_config,
            question_encoder_tokenizer=self.get_dpr_tokenizer(),
            generator_tokenizer=self.get_bart_tokenizer(),
        )
def get_dummy_pytorch_distributed_retriever(
    self, init_retrieval: bool, port=12345
) -> RagPyTorchDistributedRetriever:
    """Create a distributed retriever over ``self.get_dummy_dataset()``.

    ``load_dataset`` is patched to return the dummy dataset while the
    retriever is being constructed.

    Args:
        init_retrieval: When true, ``init_retrieval(port)`` is called on the
            retriever before it is returned.
        port: Value forwarded to ``init_retrieval``.

    Returns:
        The constructed ``RagPyTorchDistributedRetriever``.
    """
    dummy_passages = self.get_dummy_dataset()
    rag_config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
    )

    # The mock's return value is configured directly through patch's keyword arguments.
    with patch("transformers.retrieval_rag.load_dataset", return_value=dummy_passages):
        distributed_retriever = RagPyTorchDistributedRetriever(
            rag_config,
            question_encoder_tokenizer=self.get_dpr_tokenizer(),
            generator_tokenizer=self.get_bart_tokenizer(),
        )

    if not init_retrieval:
        return distributed_retriever
    distributed_retriever.init_retrieval(port)
    return distributed_retriever
def get_dummy_legacy_index_retriever(self):
    """Create a ``RagRetriever`` configured with ``index_name="legacy"``.

    Writes three files into ``self.tmpdirname`` and points both
    ``index_path`` and ``passages_path`` of the config at that directory:

    * ``hf_bert_base.hnswSQ8_correct_phi_128.c_index.index.dpr`` - the FAISS index,
    * ``hf_bert_base.hnswSQ8_correct_phi_128.c_index.index_meta.dpr`` - the pickled list of ids,
    * ``psgs_w100.tsv.pkl`` - a pickled ``{id: [text, title]}`` mapping.

    Returns:
        The constructed ``RagRetriever``.
    """
    # NOTE(review): the embeddings are one element longer than
    # retrieval_vector_size; presumably the legacy index format carries an
    # extra dimension - confirm against the legacy index implementation.
    embedding_dim = self.retrieval_vector_size + 1
    dataset = Dataset.from_dict(
        {
            "id": ["0", "1"],
            "text": ["foo", "bar"],
            "title": ["Foo", "Bar"],
            "embeddings": [np.ones(embedding_dim), 2 * np.ones(embedding_dim)],
        }
    )
    dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)

    index_file_name = os.path.join(self.tmpdirname, "hf_bert_base.hnswSQ8_correct_phi_128.c_index")
    dataset.save_faiss_index("embeddings", index_file_name + ".index.dpr")
    # Context managers guarantee the pickles are flushed and closed before
    # the retriever is built, instead of relying on garbage collection.
    with open(index_file_name + ".index_meta.dpr", "wb") as meta_file:
        pickle.dump(dataset["id"], meta_file)

    passages_file_name = os.path.join(self.tmpdirname, "psgs_w100.tsv.pkl")
    passages = {sample["id"]: [sample["text"], sample["title"]] for sample in dataset}
    with open(passages_file_name, "wb") as passages_file:
        pickle.dump(passages, passages_file)

    config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
        index_name="legacy",
        index_path=self.tmpdirname,
        passages_path=self.tmpdirname,
    )
    retriever = RagRetriever(
        config,
        question_encoder_tokenizer=self.get_dpr_tokenizer(),
        generator_tokenizer=self.get_bart_tokenizer(),
    )
    return retriever