Example #1
 def get_dummy_pytorch_distributed_retriever(
     self, init_retrieval: bool, port=12345
 ) -> RagPyTorchDistributedRetriever:
     dataset = Dataset.from_dict({
         "id": ["0", "1"],
         "text": ["foo", "bar"],
         "title": ["Foo", "Bar"],
         "embeddings": [
             np.ones(self.retrieval_vector_size),
             2 * np.ones(self.retrieval_vector_size)
         ],
     })
     dataset.add_faiss_index("embeddings",
                             string_factory="Flat",
                             metric_type=faiss.METRIC_INNER_PRODUCT)
     config = RagConfig(
         retrieval_vector_size=self.retrieval_vector_size,
         question_encoder=DPRConfig().to_dict(),
         generator=BartConfig().to_dict(),
     )
     with patch("transformers.retrieval_rag.load_dataset"
                ) as mock_load_dataset:
         mock_load_dataset.return_value = dataset
         retriever = RagPyTorchDistributedRetriever(
             config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
         )
         if init_retrieval:
             retriever.init_retrieval(port)
     return retriever
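To see the helper in use, here is a minimal, hedged sketch of a test driving it. The three return values of retrieve() follow transformers' RagRetriever; the test name, query values, and assertions are illustrative assumptions, not part of the original suite.

 # Hedged sketch, assumed to live in the same test case class as the helper.
 def test_pytorch_distributed_retriever_retrieve_sketch(self):
     n_docs = 1
     retriever = self.get_dummy_pytorch_distributed_retriever(init_retrieval=True)
     # two query vectors with the same dimensionality as the indexed embeddings
     hidden_states = np.array(
         [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)],
         dtype=np.float32,
     )
     retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
     self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
     self.assertEqual(len(doc_dicts), 2)  # one doc dict per query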
Example #2
 def get_dummy_custom_hf_index_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
     dataset = self.get_dummy_dataset()
     config = RagConfig(
         retrieval_vector_size=self.retrieval_vector_size,
         question_encoder=DPRConfig().to_dict(),
         generator=BartConfig().to_dict(),
         index_name="custom",
     )
     if from_disk:
         config.passages_path = os.path.join(self.tmpdirname, "dataset")
         config.index_path = os.path.join(self.tmpdirname, "index.faiss")
         dataset.get_index("embeddings").save(os.path.join(self.tmpdirname, "index.faiss"))
         dataset.drop_index("embeddings")
         dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
         del dataset
         retriever = RagPyTorchDistributedRetriever(
             config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
         )
     else:
         retriever = RagPyTorchDistributedRetriever(
             config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
             index=CustomHFIndex(config.retrieval_vector_size, dataset),
         )
     if init_retrieval:
         retriever.init_retrieval(port)
     return retriever
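A hedged sketch of exercising the on-disk branch; it assumes get_dummy_dataset() is the two-passage foo/bar fixture from Example #1, so the 2 * np.ones(...) embedding ("bar") scores highest under the inner-product index.

 # Sketch only: the test name and assertions are illustrative assumptions.
 def test_custom_hf_index_retriever_retrieve_sketch(self):
     retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=True)
     hidden_states = np.ones((1, self.retrieval_vector_size), dtype=np.float32)
     retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=1)
     # with METRIC_INNER_PRODUCT, the 2 * np.ones(...) passage ranks first
     self.assertEqual(doc_dicts[0]["text"][0], "bar")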
Example #3
    def test_save_load_pretrained_with_saved_config(self):

        save_dir = os.path.join(self.tmpdirname, "rag_tokenizer")
        rag_config = RagConfig(question_encoder=DPRConfig().to_dict(), generator=BartConfig().to_dict())
        rag_tokenizer = RagTokenizer(question_encoder=self.get_dpr_tokenizer(), generator=self.get_bart_tokenizer())
        rag_config.save_pretrained(save_dir)
        rag_tokenizer.save_pretrained(save_dir)
        new_rag_tokenizer = RagTokenizer.from_pretrained(save_dir, config=rag_config)
        self.assertIsInstance(new_rag_tokenizer.question_encoder, DPRQuestionEncoderTokenizer)
        self.assertEqual(new_rag_tokenizer.question_encoder.vocab, rag_tokenizer.question_encoder.vocab)
        self.assertIsInstance(new_rag_tokenizer.generator, BartTokenizer)
        self.assertEqual(new_rag_tokenizer.generator.encoder, rag_tokenizer.generator.encoder)
Example #4
 def get_dummy_canonical_hf_index_retriever(self):
     dataset = self.get_dummy_dataset()
     config = RagConfig(
         retrieval_vector_size=self.retrieval_vector_size,
         question_encoder=DPRConfig().to_dict(),
         generator=BartConfig().to_dict(),
     )
     with patch("transformers.retrieval_rag.load_dataset"
                ) as mock_load_dataset:
         mock_load_dataset.return_value = dataset
         retriever = RagRetriever(
             config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
         )
     return retriever
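For contrast with the legacy layout in Example #6, a hedged sketch of what the canonical HF index is expected to surface per retrieved document: the full set of dataset columns. The query and assertions again assume the two-passage dummy fixture.

 # Sketch only: names and assertions are assumptions, not the original test.
 def test_canonical_hf_index_retriever_retrieve_sketch(self):
     retriever = self.get_dummy_canonical_hf_index_retriever()
     hidden_states = np.ones((1, self.retrieval_vector_size), dtype=np.float32)
     retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=1)
     # the canonical index returns every dataset column for each hit
     self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])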
Example #5
 def get_dummy_pytorch_distributed_retriever(
     self, init_retrieval: bool, port=12345
 ) -> RagPyTorchDistributedRetriever:
     dataset = self.get_dummy_dataset()
     config = RagConfig(
         retrieval_vector_size=self.retrieval_vector_size,
         question_encoder=DPRConfig().to_dict(),
         generator=BartConfig().to_dict(),
     )
     with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
         mock_load_dataset.return_value = dataset
         retriever = RagPyTorchDistributedRetriever(
             config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
         )
         if init_retrieval:
             retriever.init_retrieval(port)
     return retriever
Example #6
    def get_dummy_legacy_index_retriever(self):
        dataset = Dataset.from_dict({
            "id": ["0", "1"],
            "text": ["foo", "bar"],
            "title": ["Foo", "Bar"],
            "embeddings": [
                np.ones(self.retrieval_vector_size + 1),
                2 * np.ones(self.retrieval_vector_size + 1)
            ],
        })
        dataset.add_faiss_index("embeddings",
                                string_factory="Flat",
                                metric_type=faiss.METRIC_INNER_PRODUCT)

        index_file_name = os.path.join(
            self.tmpdirname, "hf_bert_base.hnswSQ8_correct_phi_128.c_index")
        dataset.save_faiss_index("embeddings", index_file_name + ".index.dpr")
        pickle.dump(dataset["id"],
                    open(index_file_name + ".index_meta.dpr", "wb"))

        passages_file_name = os.path.join(self.tmpdirname, "psgs_w100.tsv.pkl")
        passages = {
            sample["id"]: [sample["text"], sample["title"]]
            for sample in dataset
        }
        with open(passages_file_name, "wb") as passages_file:
            pickle.dump(passages, passages_file)

        config = RagConfig(
            retrieval_vector_size=self.retrieval_vector_size,
            question_encoder=DPRConfig().to_dict(),
            generator=BartConfig().to_dict(),
            index_name="legacy",
            index_path=self.tmpdirname,
            passages_path=self.tmpdirname,
        )
        retriever = RagRetriever(
            config,
            question_encoder_tokenizer=self.get_dpr_tokenizer(),
            generator_tokenizer=self.get_bart_tokenizer())
        return retriever
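A matching hedged sketch for the legacy retriever: unlike the canonical index above, the pickled passages file only carries text and title, so that is all the doc dicts are expected to expose. The test name and assertions are assumptions for illustration.

 # Sketch only: the legacy passages pickle maps id -> [text, title].
 def test_legacy_index_retriever_retrieve_sketch(self):
     retriever = self.get_dummy_legacy_index_retriever()
     hidden_states = np.ones((1, self.retrieval_vector_size), dtype=np.float32)
     retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=1)
     self.assertEqual(sorted(doc_dicts[0]), ["text", "title"])
     self.assertEqual(doc_dicts[0]["text"][0], "bar")  # 2 * np.ones(...) ranks first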