예제 #1
0
 def get_dummy_custom_hf_index_retriever(self, from_disk: bool):
     """Build a ``RagRetriever`` whose config uses ``index_name="custom"``.

     Args:
         from_disk: When true, the dummy dataset and its ``"embeddings"``
             index are written under ``self.tmpdirname`` and the retriever
             is created from the config paths alone. When false, the
             in-memory dataset is wrapped in a ``CustomHFIndex`` and passed
             to the retriever directly.

     Returns:
         The constructed ``RagRetriever``.
     """
     dataset = self.get_dummy_dataset()
     config = RagConfig(
         retrieval_vector_size=self.retrieval_vector_size,
         question_encoder=DPRConfig().to_dict(),
         generator=BartConfig().to_dict(),
         index_name="custom",
     )

     if not from_disk:
         # In-memory variant: hand the dataset-backed index to the retriever.
         return RagRetriever(
             config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
             index=CustomHFIndex(config.retrieval_vector_size, dataset),
         )

     # On-disk variant: point the config at the files, then write them.
     passages_path = os.path.join(self.tmpdirname, "dataset")
     index_path = os.path.join(self.tmpdirname, "index.faiss")
     config.passages_path = passages_path
     config.index_path = index_path

     dataset.get_index("embeddings").save(index_path)
     # The index is dropped before the dataset itself is saved.
     dataset.drop_index("embeddings")
     dataset.save_to_disk(passages_path)
     # Release the in-memory dataset before the retriever is built; no
     # index argument is passed, so presumably the retriever reloads the
     # data from the config paths (confirm against RagRetriever).
     del dataset

     return RagRetriever(
         config,
         question_encoder_tokenizer=self.get_dpr_tokenizer(),
         generator_tokenizer=self.get_bart_tokenizer(),
     )
예제 #2
0
 def get_dummy_canonical_hf_index_retriever(self):
     """Build a ``RagRetriever`` from a config with no explicit index name.

     ``load_dataset`` inside ``transformers.models.rag.retrieval_rag`` is
     patched for the duration of the constructor call so that it returns
     the dummy dataset.

     Returns:
         The constructed ``RagRetriever``.
     """
     dummy_dataset = self.get_dummy_dataset()
     rag_config = RagConfig(
         retrieval_vector_size=self.retrieval_vector_size,
         question_encoder=DPRConfig().to_dict(),
         generator=BartConfig().to_dict(),
     )
     # The patch is only active while the retriever is being constructed.
     with patch(
             "transformers.models.rag.retrieval_rag.load_dataset",
             return_value=dummy_dataset,
     ):
         return RagRetriever(
             rag_config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
         )
예제 #3
0
    def get_dummy_legacy_index_retriever(self):
        """Build a ``RagRetriever`` whose config uses ``index_name="legacy"``.

        A two-row dataset is created, indexed with a flat inner-product
        FAISS index, and written to ``self.tmpdirname`` as three files: the
        FAISS index (``.index.dpr``), a pickled list of passage ids
        (``.index_meta.dpr``), and a pickled ``{id: [text, title]}`` mapping
        (``psgs_w100.tsv.pkl``). The config's ``index_path`` is set to that
        directory.

        Returns:
            The constructed ``RagRetriever``.
        """
        # NOTE(review): the embeddings are one element longer than
        # retrieval_vector_size; presumably the legacy index format expects
        # an extra dimension -- confirm against the legacy index loader.
        dataset = Dataset.from_dict({
            "id": ["0", "1"],
            "text": ["foo", "bar"],
            "title": ["Foo", "Bar"],
            "embeddings": [
                np.ones(self.retrieval_vector_size + 1),
                2 * np.ones(self.retrieval_vector_size + 1)
            ],
        })
        dataset.add_faiss_index("embeddings",
                                string_factory="Flat",
                                metric_type=faiss.METRIC_INNER_PRODUCT)

        index_file_name = os.path.join(
            self.tmpdirname, "hf_bert_base.hnswSQ8_correct_phi_128.c_index")
        dataset.save_faiss_index("embeddings", index_file_name + ".index.dpr")
        # Use context managers so each file is flushed and closed before
        # the retriever is built, instead of leaking the handle.
        with open(index_file_name + ".index_meta.dpr", "wb") as meta_file:
            pickle.dump(dataset["id"], meta_file)

        passages_file_name = os.path.join(self.tmpdirname, "psgs_w100.tsv.pkl")
        passages = {
            sample["id"]: [sample["text"], sample["title"]]
            for sample in dataset
        }
        with open(passages_file_name, "wb") as passages_file:
            pickle.dump(passages, passages_file)

        config = RagConfig(
            retrieval_vector_size=self.retrieval_vector_size,
            question_encoder=DPRConfig().to_dict(),
            generator=BartConfig().to_dict(),
            index_name="legacy",
            index_path=self.tmpdirname,
        )
        retriever = RagRetriever(
            config,
            question_encoder_tokenizer=self.get_dpr_tokenizer(),
            generator_tokenizer=self.get_bart_tokenizer())
        return retriever