Beispiel #1
0
 def get_dummy_custom_hf_index_retriever(self, from_disk: bool):
     """Build a ``RagRetriever`` configured with ``index_name="custom"``.

     Args:
         from_disk: When false, the dummy dataset is wrapped in a
             ``CustomHFIndex`` and handed to the retriever directly. When
             true, the "embeddings" index and the dataset are written under
             ``self.tmpdirname``, their locations are recorded on the config,
             and the retriever is created without an explicit index.

     Returns:
         The constructed ``RagRetriever``.
     """
     dummy_dataset = self.get_dummy_dataset()
     config = RagConfig(
         retrieval_vector_size=self.retrieval_vector_size,
         question_encoder=DPRConfig().to_dict(),
         generator=BartConfig().to_dict(),
         index_name="custom",
     )

     # In-memory case: pass the dataset-backed index straight in.
     if not from_disk:
         return RagRetriever(
             config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
             index=CustomHFIndex(config.retrieval_vector_size, dummy_dataset),
         )

     # On-disk case: persist the index and the passages separately.
     passages_dir = os.path.join(self.tmpdirname, "dataset")
     faiss_file = os.path.join(self.tmpdirname, "index.faiss")
     config.passages_path = passages_dir
     config.index_path = faiss_file

     dummy_dataset.get_index("embeddings").save(faiss_file)
     # The index is dropped before the dataset itself is saved.
     dummy_dataset.drop_index("embeddings")
     dummy_dataset.save_to_disk(passages_dir)
     # Release the local reference before the retriever is created.
     del dummy_dataset

     # No index argument here; presumably the retriever loads it from the
     # paths stored on the config -- TODO confirm against RagRetriever.
     return RagRetriever(
         config,
         question_encoder_tokenizer=self.get_dpr_tokenizer(),
         generator_tokenizer=self.get_bart_tokenizer(),
     )
 @classmethod
 def from_pretrained(cls,
                     retriever_name_or_path,
                     actor_handles,
                     indexed_dataset=None,
                     **kwargs):
     """Alternate constructor: build a retriever from a pretrained name/path.

     This must be a classmethod: it receives the class as ``cls`` and
     instantiates it, so it is meant to be called on the class itself
     (``SomeRetriever.from_pretrained(name, handles)``).

     Args:
         retriever_name_or_path: Identifier forwarded to
             ``RagConfig.from_pretrained`` (when no config is supplied) and
             to ``RagTokenizer.from_pretrained``.
         actor_handles: Passed to the constructor as ``retrieval_workers``.
         indexed_dataset: Optional dataset. When given, ``config.index_name``
             is overwritten with ``"custom"`` and the dataset is wrapped in a
             ``CustomHFIndex``; otherwise the index comes from
             ``cls._build_index(config)``.
         **kwargs: An optional ``config`` entry is popped and used as the
             configuration. The remaining entries are forwarded to
             ``RagConfig.from_pretrained`` only when no (truthy) config was
             supplied.

     Returns:
         A new instance of ``cls``.
     """
     # Presumably these raise when the optional `datasets` / `faiss`
     # packages are unavailable -- verify against their definitions.
     requires_datasets(cls)
     requires_faiss(cls)

     # A caller-provided config wins; otherwise load one from the checkpoint.
     config = kwargs.pop("config", None) or RagConfig.from_pretrained(
         retriever_name_or_path, **kwargs)

     rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path,
                                                  config=config)
     question_encoder_tokenizer = rag_tokenizer.question_encoder
     generator_tokenizer = rag_tokenizer.generator

     if indexed_dataset is not None:
         # NOTE: this mutates the config object, including one passed in by
         # the caller.
         config.index_name = "custom"
         index = CustomHFIndex(config.retrieval_vector_size,
                               indexed_dataset)
     else:
         index = cls._build_index(config)

     return cls(
         config,
         question_encoder_tokenizer=question_encoder_tokenizer,
         generator_tokenizer=generator_tokenizer,
         retrieval_workers=actor_handles,
         index=index,
     )
 def get_dummy_custom_hf_index_ray_retriever(self, init_retrieval: bool,
                                             from_disk: bool):
     """Build a ``RagRayDistributedRetriever`` with a single Ray worker.

     Args:
         init_retrieval: When true, ``init_retrieval()`` is called on the
             retriever before it is returned.
         from_disk: When true, the "embeddings" index and the dataset are
             written under ``self.tmpdirname`` and the index is re-created
             with ``CustomHFIndex.load_from_disk``; otherwise the in-memory
             dummy dataset is wrapped in a ``CustomHFIndex`` directly.

     Returns:
         The constructed ``RagRayDistributedRetriever``.
     """
     # Ray runs in local mode because the sys.path modifications made at the
     # top of the file are not propagated to remote workers. See
     # https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
     ray.init(local_mode=True)

     dummy_dataset = self.get_dummy_dataset()
     config = RagConfig(
         retrieval_vector_size=self.retrieval_vector_size,
         question_encoder=DPRConfig().to_dict(),
         generator=BartConfig().to_dict(),
         index_name="custom",
     )

     # Exactly one remote retrieval worker.
     workers = [ray.remote(RayRetriever).remote()]

     if from_disk:
         passages_dir = os.path.join(self.tmpdirname, "dataset")
         faiss_file = os.path.join(self.tmpdirname, "index.faiss")
         config.passages_path = passages_dir
         config.index_path = faiss_file

         dummy_dataset.get_index("embeddings").save(faiss_file)
         # The index is dropped before the dataset itself is saved.
         dummy_dataset.drop_index("embeddings")
         dummy_dataset.save_to_disk(passages_dir)
         # Release the local reference; the index is reloaded from disk below.
         del dummy_dataset

     # Tokenizers are fetched before the index is built, in both modes.
     retriever = RagRayDistributedRetriever(
         config,
         question_encoder_tokenizer=self.get_dpr_tokenizer(),
         generator_tokenizer=self.get_bart_tokenizer(),
         retrieval_workers=workers,
         index=(
             CustomHFIndex.load_from_disk(
                 vector_size=config.retrieval_vector_size,
                 dataset_path=config.passages_path,
                 index_path=config.index_path,
             )
             if from_disk
             # Only reached when the dataset was not deleted above.
             else CustomHFIndex(config.retrieval_vector_size, dummy_dataset)
         ),
     )

     if init_retrieval:
         retriever.init_retrieval()
     return retriever