def get_dummy_ray_distributed_retriever(
         self, init_retrieval: bool) -> RagRayDistributedRetriever:
     # Have to run in local mode because sys.path modifications at top of
     # file are not propagated to remote workers.
     # https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
     ray.init(local_mode=True)
     config = RagConfig(
         retrieval_vector_size=self.retrieval_vector_size,
         question_encoder=DPRConfig().to_dict(),
         generator=BartConfig().to_dict(),
     )
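     # Wrap RayRetriever as a Ray actor class and spawn a single
     # retrieval worker.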
     remote_cls = ray.remote(RayRetriever)
     workers = [remote_cls.remote() for _ in range(1)]
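     # Patch load_dataset so the retriever builds its index from the
     # in-memory dummy dataset instead of downloading the real one.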
     with patch("transformers.models.rag.retrieval_rag.load_dataset"
                ) as mock_load_dataset:
         mock_load_dataset.return_value = self.get_dummy_dataset()
         retriever = RagRayDistributedRetriever(
             config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
             retrieval_workers=workers,
         )
         if init_retrieval:
             retriever.init_retrieval()
     return retriever

def get_dummy_custom_hf_index_ray_retriever(
         self, init_retrieval: bool,
         from_disk: bool) -> RagRayDistributedRetriever:
     # Have to run in local mode because sys.path modifications at top of
     # file are not propagated to remote workers.
     # https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
     ray.init(local_mode=True)
     dataset = self.get_dummy_dataset()
     config = RagConfig(
         retrieval_vector_size=self.retrieval_vector_size,
         question_encoder=DPRConfig().to_dict(),
         generator=BartConfig().to_dict(),
         index_name="custom",
     )
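     # index_name="custom" makes RAG expect a user-provided index, which
     # is passed explicitly to the retriever below.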
     remote_cls = ray.remote(RayRetriever)
     workers = [remote_cls.remote() for _ in range(1)]
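     # Either persist the dummy dataset and its faiss index to disk and
     # reload them via CustomHFIndex.load_from_disk, or wrap the
     # in-memory dataset directly.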
     if from_disk:
         config.passages_path = os.path.join(self.tmpdirname, "dataset")
         config.index_path = os.path.join(self.tmpdirname, "index.faiss")
         dataset.get_index("embeddings").save(
             os.path.join(self.tmpdirname, "index.faiss"))
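         # save_to_disk does not store an attached faiss index, so the
         # index is saved separately above and dropped here first.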
         dataset.drop_index("embeddings")
         dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
         del dataset
         retriever = RagRayDistributedRetriever(
             config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
             retrieval_workers=workers,
             index=CustomHFIndex.load_from_disk(
                 vector_size=config.retrieval_vector_size,
                 dataset_path=config.passages_path,
                 index_path=config.index_path,
             ),
         )
     else:
         retriever = RagRayDistributedRetriever(
             config,
             question_encoder_tokenizer=self.get_dpr_tokenizer(),
             generator_tokenizer=self.get_bart_tokenizer(),
             retrieval_workers=workers,
             index=CustomHFIndex(config.retrieval_vector_size, dataset),
         )
     if init_retrieval:
         retriever.init_retrieval()
     return retriever
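
# A minimal usage sketch (hypothetical; not part of the original suite)
# showing how these helpers might be exercised. It assumes numpy is
# imported as `np` at the top of the file and relies on the standard
# RagRetriever.retrieve(question_hidden_states, n_docs) signature.
def test_ray_retriever_sketch(self):
     n_docs = 1
     retriever = self.get_dummy_ray_distributed_retriever(
         init_retrieval=True)
     hidden_states = np.ones(
         (2, self.retrieval_vector_size), dtype=np.float32)
     doc_embeds, doc_ids, doc_dicts = retriever.retrieve(
         hidden_states, n_docs)
     self.assertEqual(
         doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
     self.assertEqual(doc_ids.shape, (2, n_docs))
     ray.shutdown()  # the helpers call ray.init, so shut down per test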