Example #1
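The snippets below are methods of a RagRetriever test class. A minimal set of imports they assume is sketched here; module paths are version-specific (CustomHFIndex and the patched load_dataset live in transformers.retrieval_rag in the release these tests target, and moved to transformers.models.rag.retrieval_rag in later releases):

 import os
 import pickle
 import tempfile
 from unittest.mock import patch

 import faiss
 import numpy as np
 from datasets import Dataset

 from transformers import BartConfig, DPRConfig, RagConfig, RagRetriever
 from transformers.retrieval_rag import CustomHFIndex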
 def get_dummy_custom_hf_index_retriever(self, from_disk: bool):
     # Build a RagRetriever over a custom HF dataset index, either
     # serialized to disk first or passed directly in memory.
     dataset = self.get_dummy_dataset()
     config = RagConfig(
         retrieval_vector_size=self.retrieval_vector_size,
         question_encoder=DPRConfig().to_dict(),
         generator=BartConfig().to_dict(),
         index_name="custom",
     )
     if from_disk:
         # Persist both the FAISS index and the dataset, then let the
         # retriever reload them from the configured paths.
         config.passages_path = os.path.join(self.tmpdirname, "dataset")
         config.index_path = os.path.join(self.tmpdirname, "index.faiss")
         dataset.get_index("embeddings").save(
             os.path.join(self.tmpdirname, "index.faiss"))
         dataset.drop_index("embeddings")
         dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
         del dataset
         retriever = RagRetriever(
             config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
         )
     else:
         retriever = RagRetriever(
             config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
             index=CustomHFIndex(config.retrieval_vector_size, dataset),
         )
     return retriever
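A minimal sketch of how this helper might be exercised inside the test class, assuming retrieve returns a (retrieved_doc_embeds, doc_ids, doc_dicts) tuple as it does in this version of the library:

 # Hypothetical round trip: build the retriever from disk and query it.
 retriever = self.get_dummy_custom_hf_index_retriever(from_disk=True)
 hidden_states = np.ones((1, self.retrieval_vector_size), dtype=np.float32)
 retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(
     hidden_states, n_docs=1)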
Example #2
 def get_dummy_hf_index_retriever(self):
     # Two-passage dummy dataset with a flat inner-product FAISS index.
     dataset = Dataset.from_dict({
         "id": ["0", "1"],
         "text": ["foo", "bar"],
         "title": ["Foo", "Bar"],
         "embeddings": [
             np.ones(self.retrieval_vector_size),
             2 * np.ones(self.retrieval_vector_size)
         ],
     })
     dataset.add_faiss_index("embeddings",
                             string_factory="Flat",
                             metric_type=faiss.METRIC_INNER_PRODUCT)
     config = RagConfig(
         retrieval_vector_size=self.retrieval_vector_size,
         question_encoder=DPRConfig().to_dict(),
         generator=BartConfig().to_dict(),
     )
     with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
         mock_load_dataset.return_value = dataset
         retriever = RagRetriever(
             config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
         )
     return retriever
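The METRIC_INNER_PRODUCT choice above determines the ranking for these dummy embeddings: against a query of all ones, passage 0 (all ones) scores d while passage 1 (all twos) scores 2d, so "bar" is retrieved first, and a query of all minus-ones reverses that. A standalone check of the arithmetic:

 import numpy as np
 d = 8  # stand-in for self.retrieval_vector_size
 query = np.ones(d, dtype=np.float32)
 score_foo = float(query @ np.ones(d))        # d
 score_bar = float(query @ (2 * np.ones(d)))  # 2 * d
 assert score_bar > score_foo  # "bar" ranks first under inner product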
Example #3
 def test_legacy_hf_index_retriever_save_and_from_pretrained(self):
     retriever = self.get_dummy_legacy_index_retriever()
     with tempfile.TemporaryDirectory() as tmp_dirname:
         retriever.save_pretrained(tmp_dirname)
         retriever = RagRetriever.from_pretrained(tmp_dirname)
         self.assertIsInstance(retriever, RagRetriever)
         hidden_states = np.array(
             [np.ones(self.retrieval_vector_size),
              -np.ones(self.retrieval_vector_size)],
             dtype=np.float32,
         )
         out = retriever.retrieve(hidden_states, n_docs=1)
         self.assertIsNotNone(out)
Example #4
 def test_canonical_hf_index_retriever_save_and_from_pretrained(self):
     retriever = self.get_dummy_canonical_hf_index_retriever()
     with tempfile.TemporaryDirectory() as tmp_dirname:
         with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
             mock_load_dataset.return_value = self.get_dummy_dataset()
             retriever.save_pretrained(tmp_dirname)
             retriever = RagRetriever.from_pretrained(tmp_dirname)
             self.assertIsInstance(retriever, RagRetriever)
             hidden_states = np.array(
                 [np.ones(self.retrieval_vector_size),
                  -np.ones(self.retrieval_vector_size)],
                 dtype=np.float32,
             )
             out = retriever.retrieve(hidden_states, n_docs=1)
             self.assertIsNotNone(out)
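The `out is not None` check is deliberately loose. A sketch of tighter assertions for the canonical case, assuming the (retrieved_doc_embeds, doc_ids, doc_dicts) return layout and the two-passage dummy dataset:

 retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(
     hidden_states, n_docs=1)
 self.assertEqual(retrieved_doc_embeds.shape,
                  (2, 1, self.retrieval_vector_size))
 self.assertEqual(doc_ids.shape, (2, 1))
 self.assertEqual(len(doc_dicts), 2)  # one doc dict per query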
Example #5
 def get_dummy_canonical_hf_index_retriever(self):
     # Canonical index: the retriever resolves its dataset through
     # load_dataset, which is stubbed below.
     dataset = self.get_dummy_dataset()
     config = RagConfig(
         retrieval_vector_size=self.retrieval_vector_size,
         question_encoder=DPRConfig().to_dict(),
         generator=BartConfig().to_dict(),
     )
     with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
         mock_load_dataset.return_value = dataset
         retriever = RagRetriever(
             config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
         )
     return retriever
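In this helper and the canonical tests above, load_dataset is patched because a RagRetriever built without index_name="custom" or "legacy" resolves its index by calling load_dataset on the configured dataset (wiki_dpr by default), which would trigger a large download. The mock substitutes the two-row dummy dataset instead; one can additionally verify the stub was hit:

 with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
     mock_load_dataset.return_value = dataset
     retriever = RagRetriever(
         config,
         question_encoder_tokenizer=self.get_dpr_tokenizer(),
         generator_tokenizer=self.get_bart_tokenizer(),
     )
     mock_load_dataset.assert_called()  # the stub served the dataset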
 def get_dummy_legacy_index_retriever(self):
     # Legacy-format dummy index: embeddings carry one dimension more than
     # retrieval_vector_size, as the legacy index layout expects.
     dataset = Dataset.from_dict({
         "id": ["0", "1"],
         "text": ["foo", "bar"],
         "title": ["Foo", "Bar"],
         "embeddings": [
             np.ones(self.retrieval_vector_size + 1),
             2 * np.ones(self.retrieval_vector_size + 1)
         ],
     })
     dataset.add_faiss_index("embeddings",
                             string_factory="Flat",
                             metric_type=faiss.METRIC_INNER_PRODUCT)

     # Serialize the FAISS index and its id metadata under the file
     # names the legacy loader expects.
     index_file_name = os.path.join(
         self.tmpdirname, "hf_bert_base.hnswSQ8_correct_phi_128.c_index")
     dataset.save_faiss_index("embeddings", index_file_name + ".index.dpr")
     with open(index_file_name + ".index_meta.dpr", "wb") as f:
         pickle.dump(dataset["id"], f)

     # Passages are pickled as a dict mapping id -> [text, title].
     passages_file_name = os.path.join(self.tmpdirname, "psgs_w100.tsv.pkl")
     passages = {
         sample["id"]: [sample["text"], sample["title"]]
         for sample in dataset
     }
     with open(passages_file_name, "wb") as f:
         pickle.dump(passages, f)

     config = RagConfig(
         retrieval_vector_size=self.retrieval_vector_size,
         question_encoder=DPRConfig().to_dict(),
         generator=BartConfig().to_dict(),
         index_name="legacy",
         index_path=self.tmpdirname,
         passages_path=self.tmpdirname,
     )
     retriever = RagRetriever(
         config,
         question_encoder_tokenizer=self.get_dpr_tokenizer(),
         generator_tokenizer=self.get_bart_tokenizer())
     return retriever
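The legacy layout written above mirrors the original DPR artifacts: a FAISS index plus an id-metadata pickle under the hf_bert_base.hnswSQ8_correct_phi_128.c_index prefix, and a psgs_w100.tsv.pkl pickle mapping each passage id to [text, title]. A hypothetical sanity check of that on-disk format:

 with open(os.path.join(self.tmpdirname, "psgs_w100.tsv.pkl"), "rb") as f:
     passages = pickle.load(f)
 assert passages["0"] == ["foo", "Foo"]  # id -> [text, title]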