def get_dummy_custom_hf_index_retriever(self, from_disk: bool):
    """Build a ``RagRetriever`` configured with ``index_name="custom"``.

    Args:
        from_disk: When true, the dummy dataset and its ``"embeddings"``
            index are written under ``self.tmpdirname`` and the retriever is
            created from the config paths alone (no ``index`` argument is
            passed). When false, the in-memory dataset is wrapped in a
            ``CustomHFIndex`` and handed to the retriever directly.

    Returns:
        The constructed ``RagRetriever``.
    """
    dummy_dataset = self.get_dummy_dataset()
    rag_config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
        index_name="custom",
    )

    # In-memory case: hand the dataset to the retriever via an explicit index.
    if not from_disk:
        return RagRetriever(
            rag_config,
            question_encoder_tokenizer=self.get_dpr_tokenizer(),
            generator_tokenizer=self.get_bart_tokenizer(),
            index=CustomHFIndex(rag_config.retrieval_vector_size, dummy_dataset),
        )

    # On-disk case: persist the index and the passages separately, and record
    # both locations on the config.
    passages_dir = os.path.join(self.tmpdirname, "dataset")
    faiss_file = os.path.join(self.tmpdirname, "index.faiss")
    rag_config.passages_path = passages_dir
    rag_config.index_path = faiss_file

    dummy_dataset.get_index("embeddings").save(faiss_file)
    # The index is dropped before the dataset itself is saved.
    dummy_dataset.drop_index("embeddings")
    dummy_dataset.save_to_disk(passages_dir)
    # Release the in-memory dataset before the retriever is created.
    del dummy_dataset

    return RagRetriever(
        rag_config,
        question_encoder_tokenizer=self.get_dpr_tokenizer(),
        generator_tokenizer=self.get_bart_tokenizer(),
    )
def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
    """Create a retriever from a pretrained name or path.

    Args:
        retriever_name_or_path: Identifier forwarded to
            ``RagConfig.from_pretrained`` and ``RagTokenizer.from_pretrained``.
        actor_handles: Passed through to the constructor as
            ``retrieval_workers``.
        indexed_dataset: Optional dataset. When given, ``config.index_name``
            is set to ``"custom"`` and the dataset is wrapped in a
            ``CustomHFIndex``; otherwise the index comes from
            ``cls._build_index(config)``.
        **kwargs: May contain ``config``; it is popped first, and the
            remaining entries are forwarded to ``RagConfig.from_pretrained``
            only when no (truthy) ``config`` was supplied.

    Returns:
        A new instance of ``cls``.
    """
    requires_datasets(cls)
    requires_faiss(cls)

    # An explicitly supplied config wins; otherwise load one from the path.
    config = kwargs.pop("config", None)
    if not config:
        config = RagConfig.from_pretrained(retriever_name_or_path, **kwargs)

    tokenizers = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
    q_encoder_tok, generator_tok = tokenizers.question_encoder, tokenizers.generator

    if indexed_dataset is None:
        retrieval_index = cls._build_index(config)
    else:
        # A caller-provided dataset is served through the custom index type.
        config.index_name = "custom"
        retrieval_index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset)

    return cls(
        config,
        question_encoder_tokenizer=q_encoder_tok,
        generator_tokenizer=generator_tok,
        retrieval_workers=actor_handles,
        index=retrieval_index,
    )
def get_dummy_custom_hf_index_ray_retriever(self, init_retrieval: bool, from_disk: bool):
    """Build a ``RagRayDistributedRetriever`` backed by a ``CustomHFIndex``.

    Args:
        init_retrieval: When true, ``init_retrieval()`` is called on the
            retriever before it is returned.
        from_disk: When true, the dummy dataset and its ``"embeddings"``
            index are written under ``self.tmpdirname`` and the index is
            reloaded with ``CustomHFIndex.load_from_disk``. When false, the
            in-memory dataset is wrapped in a ``CustomHFIndex`` directly.

    Returns:
        The constructed ``RagRayDistributedRetriever``.
    """
    # Ray has to run in local mode because the sys.path modifications at the
    # top of the file are not propagated to remote workers.
    # https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
    ray.init(local_mode=True)

    dataset = self.get_dummy_dataset()
    config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
        index_name="custom",
    )

    # A single remote worker is enough for the dummy setup.
    ray_workers = [ray.remote(RayRetriever).remote()]

    if from_disk:
        # Persist the index and the passages separately, and record both
        # locations on the config.
        passages_dir = os.path.join(self.tmpdirname, "dataset")
        faiss_file = os.path.join(self.tmpdirname, "index.faiss")
        config.passages_path = passages_dir
        config.index_path = faiss_file

        dataset.get_index("embeddings").save(faiss_file)
        dataset.drop_index("embeddings")
        dataset.save_to_disk(passages_dir)
        # Release the in-memory dataset; the index is reloaded from disk below.
        del dataset

    # Tokenizers are created before the index, as keyword order dictates.
    retriever_kwargs = {
        "question_encoder_tokenizer": self.get_dpr_tokenizer(),
        "generator_tokenizer": self.get_bart_tokenizer(),
        "retrieval_workers": ray_workers,
    }
    if from_disk:
        retriever_kwargs["index"] = CustomHFIndex.load_from_disk(
            vector_size=config.retrieval_vector_size,
            dataset_path=config.passages_path,
            index_path=config.index_path,
        )
    else:
        retriever_kwargs["index"] = CustomHFIndex(config.retrieval_vector_size, dataset)

    retriever = RagRayDistributedRetriever(config, **retriever_kwargs)

    if init_retrieval:
        retriever.init_retrieval()
    return retriever