def get_dummy_custom_hf_index_pytorch_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
    """Build a ``RagPyTorchDistributedRetriever`` backed by the dummy dataset.

    Args:
        init_retrieval: When true, ``retriever.init_retrieval(port)`` is
            called before the retriever is returned.
        from_disk: When true, the dataset's ``"embeddings"`` index and the
            dataset itself are written under ``self.tmpdirname``, their
            paths are stored on the config, and the retriever is built
            without an explicit ``index`` argument. When false, an
            in-memory ``CustomHFIndex`` wrapping the dataset is passed in.
        port: Forwarded to ``init_retrieval``.

    Returns:
        The constructed retriever.
    """
    dummy_dataset = self.get_dummy_dataset()
    rag_config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
        index_name="custom",
    )

    if from_disk:
        passages_dir = os.path.join(self.tmpdirname, "dataset")
        index_file = os.path.join(self.tmpdirname, "index.faiss")
        rag_config.passages_path = passages_dir
        rag_config.index_path = index_file

        # Persist the index on its own, detach it, then persist the dataset.
        dummy_dataset.get_index("embeddings").save(index_file)
        dummy_dataset.drop_index("embeddings")
        dummy_dataset.save_to_disk(passages_dir)
        # Drop the in-memory dataset before the retriever is constructed.
        del dummy_dataset

    retriever_kwargs = {
        "question_encoder_tokenizer": self.get_dpr_tokenizer(),
        "generator_tokenizer": self.get_bart_tokenizer(),
    }
    if not from_disk:
        # Only the in-memory variant hands the retriever a ready-made index.
        retriever_kwargs["index"] = CustomHFIndex(rag_config.retrieval_vector_size, dummy_dataset)

    retriever = RagPyTorchDistributedRetriever(rag_config, **retriever_kwargs)

    if init_retrieval:
        retriever.init_retrieval(port)
    return retriever
def get_dummy_custom_hf_index_ray_retriever(self, init_retrieval: bool, from_disk: bool):
    """Build a ``RagRayDistributedRetriever`` backed by the dummy dataset.

    Args:
        init_retrieval: When true, ``retriever.init_retrieval()`` is called
            before the retriever is returned.
        from_disk: When true, the dataset's ``"embeddings"`` index and the
            dataset itself are written under ``self.tmpdirname`` and the
            retriever's index is created with
            ``CustomHFIndex.load_from_disk`` from the paths stored on the
            config. When false, an in-memory ``CustomHFIndex`` wrapping the
            dataset is used.

    Returns:
        The constructed retriever, with a single Ray retrieval worker.
    """
    # Have to run in local mode because sys.path modifications at top of
    # file are not propagated to remote workers.
    # https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
    ray.init(local_mode=True)

    dummy_dataset = self.get_dummy_dataset()
    rag_config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
        index_name="custom",
    )

    # A single remote retrieval worker is enough for these tests.
    remote_retriever_cls = ray.remote(RayRetriever)
    workers = [remote_retriever_cls.remote()]

    if from_disk:
        passages_dir = os.path.join(self.tmpdirname, "dataset")
        index_file = os.path.join(self.tmpdirname, "index.faiss")
        rag_config.passages_path = passages_dir
        rag_config.index_path = index_file

        # Persist the index on its own, detach it, then persist the dataset.
        dummy_dataset.get_index("embeddings").save(index_file)
        dummy_dataset.drop_index("embeddings")
        dummy_dataset.save_to_disk(passages_dir)
        # Drop the in-memory dataset before anything is reloaded from disk.
        del dummy_dataset

    # Tokenizers are created before the index, as keyword arguments would be.
    dpr_tokenizer = self.get_dpr_tokenizer()
    bart_tokenizer = self.get_bart_tokenizer()

    if from_disk:
        hf_index = CustomHFIndex.load_from_disk(
            vector_size=rag_config.retrieval_vector_size,
            dataset_path=rag_config.passages_path,
            index_path=rag_config.index_path,
        )
    else:
        hf_index = CustomHFIndex(rag_config.retrieval_vector_size, dummy_dataset)

    retriever = RagRayDistributedRetriever(
        rag_config,
        question_encoder_tokenizer=dpr_tokenizer,
        generator_tokenizer=bart_tokenizer,
        retrieval_workers=workers,
        index=hf_index,
    )

    if init_retrieval:
        retriever.init_retrieval()
    return retriever