def main(
    rag_example_args: "RagExampleArguments",
    processing_args: "ProcessingArguments",
    index_hnsw_args: "IndexHnswArguments",
):
    """Build a DPR-embedded passage dataset from a TSV file and attach a FAISS HNSW index."""
    ######################################
    logger.info("Step 1 - Create the dataset")
    ######################################
    # RAG needs exactly three columns:
    #   title (string)      - title of the document
    #   text (string)       - one passage of the document
    #   embeddings (dim d)  - DPR representation of the passage
    assert os.path.isfile(rag_example_args.csv_path), "Please provide a valid path to a csv file"

    # Load the tab-separated "title"/"text" file as a Dataset.
    # More info about loading csv files: https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files
    raw_ds = load_dataset(
        "csv",
        data_files=[rag_example_args.csv_path],
        split="train",
        delimiter="\t",
        column_names=["title", "text"],
    )

    # Break each document into passages of 100 words.
    passage_ds = raw_ds.map(split_documents, batched=True, num_proc=processing_args.num_proc)

    # Embed every passage with the DPR context encoder.
    encoder = DPRContextEncoder.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
    tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name)
    # Declaring float32 output (instead of the default float64) halves the on-disk size.
    schema = Features(
        {"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
    )
    embedded_ds = passage_ds.map(
        partial(embed, ctx_encoder=encoder, ctx_tokenizer=tokenizer),
        batched=True,
        batch_size=processing_args.batch_size,
        features=schema,
    )

    # Persist the passages; reload later with datasets.load_from_disk(passages_path).
    passages_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset")
    embedded_ds.save_to_disk(passages_path)

    ######################################
    logger.info("Step 2 - Index the dataset")
    ######################################
    # HNSW gives fast approximate nearest-neighbour search over the embeddings.
    index = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m, faiss.METRIC_INNER_PRODUCT)
    embedded_ds.add_faiss_index("embeddings", custom_index=index)

    # Save the index next to the dataset so it can be loaded separately.
    index_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset_hnsw_index.faiss")
    embedded_ds.get_index("embeddings").save(index_path)
def __init__(self):
    """Load the DPR context/question encoders and their matching tokenizers onto the configured device."""
    ctx_model = "facebook/dpr-ctx_encoder-multiset-base"
    question_model = "facebook/dpr-question_encoder-multiset-base"
    # Encoders first (moved to the target device), then the matching tokenizers.
    self.ctx_encoder = DPRContextEncoder.from_pretrained(ctx_model).to(Config.device)
    self.q_encoder = DPRQuestionEncoder.from_pretrained(question_model).to(Config.device)
    self.ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(ctx_model)
    self.q_tokenizer = DPRQuestionEncoderTokenizerFast.from_pretrained(question_model)
def generate_faiss_index_dataset(data, ctx_encoder_name, args, device):
    """
    Build a DPR-embedded, FAISS-indexed knowledge dataset.

    `data` is either a path to a tab-separated csv file with "title" and "text"
    columns, or a pandas DataFrame with those columns.
    Adapted from the Huggingface example script at
    https://github.com/huggingface/transformers/blob/master/examples/research_projects/rag/use_own_knowledge_dataset.py
    """
    import faiss

    loaded_from_csv = isinstance(data, str)
    if loaded_from_csv:
        knowledge_ds = load_dataset("csv", data_files=data, delimiter="\t", column_names=["title", "text"])
    else:
        knowledge_ds = HFDataset.from_pandas(data)

    # Chop documents into passages.
    splitter = partial(
        split_documents, split_text_n=args.split_text_n, split_text_character=args.split_text_character
    )
    knowledge_ds = knowledge_ds.map(splitter, batched=True, num_proc=args.process_count)

    # DPR context encoder + tokenizer used to embed every passage.
    ctx_encoder = DPRContextEncoder.from_pretrained(ctx_encoder_name).to(device=device)
    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(ctx_encoder_name)

    # Store embeddings as float32 rather than float64 to save space.
    embedding_schema = Features({
        "text": Value("string"),
        "title": Value("string"),
        "embeddings": Sequence(Value("float32")),
    })
    knowledge_ds = knowledge_ds.map(
        partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer, device=device),
        batched=True,
        batch_size=args.rag_embed_batch_size,
        features=embedding_schema,
    )

    # load_dataset without an explicit split returns a DatasetDict; unwrap it.
    if loaded_from_csv:
        knowledge_ds = knowledge_ds["train"]

    if args.save_knowledge_dataset:
        output_dataset_directory = os.path.join(args.output_dir, "knowledge_dataset")
        os.makedirs(output_dataset_directory, exist_ok=True)
        knowledge_ds.save_to_disk(output_dataset_directory)

    # Approximate nearest-neighbour index over the embeddings.
    hnsw = faiss.IndexHNSWFlat(args.faiss_d, args.faiss_m, faiss.METRIC_INNER_PRODUCT)
    knowledge_ds.add_faiss_index("embeddings", custom_index=hnsw)

    return knowledge_ds
def embed_update(ctx_encoder, total_processes, device, process_num, shard_dir, csv_path):
    """
    Embed one shard (number ``process_num`` of ``total_processes``) of the
    knowledge csv with the given DPR context encoder and save the resulting
    dataset under ``shard_dir``.
    """
    # Load and split the full knowledge base (you can also load an already split csv).
    kb_dataset = load_dataset(
        "csv", data_files=[csv_path], split="train", delimiter="\t", column_names=["title", "text"]
    )
    kb_dataset = kb_dataset.map(split_documents, batched=True, num_proc=1)

    # Pick the contiguous shard this process is responsible for.
    my_shard = [
        kb_dataset.shard(total_processes, i, contiguous=True) for i in range(total_processes)
    ][process_num]

    # Each process writes to its own arrow folder, e.g. "data_0", "data_1", ...
    passages_path = os.path.join(shard_dir, "data_" + str(process_num))

    context_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
        "facebook/dpr-ctx_encoder-multiset-base"
    )
    ctx_encoder = ctx_encoder.to(device=device)

    def embed(
        documents: dict,
        ctx_encoder: DPRContextEncoder,
        ctx_tokenizer: DPRContextEncoderTokenizerFast,
        device,
    ) -> dict:
        """Compute DPR embeddings for a batch of title/text passages."""
        tokenized = ctx_tokenizer(
            documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
        )
        pooled = ctx_encoder(tokenized["input_ids"].to(device=device), return_dict=True).pooler_output
        return {"embeddings": pooled.detach().cpu().numpy()}

    # Save embeddings as float32 instead of float64 to save space.
    shard_features = Features({
        "text": Value("string"),
        "title": Value("string"),
        "embeddings": Sequence(Value("float32")),
    })
    embedded_shard = my_shard.map(
        partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=context_tokenizer, device=device),
        batched=True,
        batch_size=16,
        features=shard_features,
    )
    embedded_shard.save_to_disk(passages_path)
def load_dataset(self) -> None:
    """Build the RAG knowledge dataset for this instance and index it with FAISS.

    Loads the comma-separated csv at ``self.csv_path``, splits documents into
    passages, embeds them with the DPR context encoder named by
    ``self.context_encoder``, saves the dataset to ``self.dataset_path`` and
    writes the HNSW index to ``self.faiss_path``.
    """
    logger.debug('loading rag dataset: %s', self.name)
    # Calls the module-level `load_dataset` from the datasets library — the
    # method name only shadows it as an attribute, not inside this scope.
    # Note this file uses ',' as delimiter (comma-separated input expected).
    self.dataset = load_dataset('csv',
                                data_files=[self.csv_path],
                                split='train',
                                delimiter=',',
                                column_names=['title', 'text'])
    # NOTE(review): with batched=False, datasets.Dataset.map ignores batch_size —
    # confirm whether batched=True was intended here (siblings use batched=True).
    self.dataset = self.dataset.map(
        split_documents,
        batched=False,
        num_proc=6,  # hard-coded worker count
        batch_size=100,
    )
    # DPR context encoder + matching tokenizer used to embed each passage.
    ctx_encoder = DPRContextEncoder.from_pretrained(
        self.context_encoder).to(device=self.device)
    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
        self.context_encoder)
    new_features = Features({
        'text': Value('string'),
        'title': Value('string'),
        'embeddings': Sequence(Value('float32'))
    })  # optional, save as float32 instead of float64 to save space
    self.dataset = self.dataset.map(
        partial(embed,
                ctx_encoder=ctx_encoder,
                ctx_tokenizer=ctx_tokenizer,
                device=self.device),
        batched=True,
        batch_size=16,  # hard-coded embedding batch size
        features=new_features,
    )
    self.dataset.save_to_disk(self.dataset_path)
    # HNSW index for approximate nearest-neighbour search over the embeddings.
    # NOTE(review): 768 (embedding dim, DPR hidden size) and 128 (HNSW M) are
    # hard-coded — confirm 768 still matches if the context encoder changes.
    index = faiss.IndexHNSWFlat(768, 128, faiss.METRIC_INNER_PRODUCT)
    self.dataset.add_faiss_index('embeddings', custom_index=index)
    self.dataset.get_index('embeddings').save(self.faiss_path)
def test_embed(self):
    """A single title/text pair should embed to a 768-dimensional DPR vector."""
    encoder = DPRContextEncoder.from_pretrained(
        'facebook/dpr-ctx_encoder-multiset-base'
    ).to(device='cpu')
    tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
        'facebook/dpr-ctx_encoder-multiset-base'
    )
    document = {'title': 'something', 'text': 'blah'}
    result = embed(document, encoder, tokenizer, 'cpu')
    self.assertEqual(len(result['embeddings'][0]), 768)
def main(
    rag_example_args: "RagExampleArguments",
    processing_args: "ProcessingArguments",
    index_hnsw_args: "IndexHnswArguments",
):
    """End-to-end RAG example: build a knowledge dataset, index it, load RAG, answer a question.

    Step 1 creates the DPR-embedded passage dataset from a tab-separated csv,
    step 2 builds a FAISS HNSW index over it, step 3 loads a RAG model wired to
    that index, and step 4 generates an answer to ``rag_example_args.question``.
    """
    ######################################
    logger.info("Step 1 - Create the dataset")
    ######################################
    # The dataset needed for RAG must have three columns:
    # - title (string): title of the document
    # - text (string): text of a passage of the document
    # - embeddings (array of dimension d): DPR representation of the passage
    assert os.path.isfile(rag_example_args.csv_path), "Please provide a valid path to a csv file"

    # You can load a Dataset object this way.
    # More info about loading csv files in the documentation:
    # https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files
    dataset = load_dataset(
        "csv", data_files=[rag_example_args.csv_path], split="train", delimiter="\t", column_names=["title", "text"]
    )

    # Then split the documents into passages of 100 words
    dataset = dataset.map(split_documents, batched=True, num_proc=processing_args.num_proc)

    # And compute the embeddings
    ctx_encoder = DPRContextEncoder.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name)
    # FIX: declare an explicit output schema so embeddings are stored as float32
    # instead of the default float64 — same values, half the disk footprint
    # (consistent with the other variant of this script in this codebase).
    from datasets import Features, Sequence, Value  # local import keeps this fix self-contained

    new_features = Features(
        {"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
    )
    dataset = dataset.map(
        partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
        batched=True,
        batch_size=processing_args.batch_size,
        features=new_features,
    )

    # And finally save your dataset
    passages_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset")
    dataset.save_to_disk(passages_path)
    # from datasets import load_from_disk
    # dataset = load_from_disk(passages_path)  # to reload the dataset

    ######################################
    logger.info("Step 2 - Index the dataset")
    ######################################
    # Let's use the Faiss implementation of HNSW for fast approximate nearest neighbor search
    index = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m, faiss.METRIC_INNER_PRODUCT)
    dataset.add_faiss_index("embeddings", custom_index=index)

    # And save the index
    index_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset_hnsw_index.faiss")
    dataset.get_index("embeddings").save(index_path)
    # dataset.load_faiss_index("embeddings", index_path)  # to reload the index

    ######################################
    logger.info("Step 3 - Load RAG")
    ######################################
    # Easy way to load the model
    retriever = RagRetriever.from_pretrained(
        rag_example_args.rag_model_name, index_name="custom", indexed_dataset=dataset
    )
    model = RagSequenceForGeneration.from_pretrained(rag_example_args.rag_model_name, retriever=retriever)
    tokenizer = RagTokenizer.from_pretrained(rag_example_args.rag_model_name)

    # For distributed fine-tuning you'll need to provide the paths instead, as the dataset and the index are loaded separately.
    # retriever = RagRetriever.from_pretrained(rag_model_name, index_name="custom", passages_path=passages_path, index_path=index_path)

    ######################################
    logger.info("Step 4 - Have fun")
    ######################################
    question = rag_example_args.question or "What does Moses' rod turn into ?"
    input_ids = tokenizer.question_encoder(question, return_tensors="pt")["input_ids"]
    generated = model.generate(input_ids)
    generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
    logger.info("Q: " + question)
    logger.info("A: " + generated_string)
def __init__(self, hparams, **kwargs):
    """Build the RAG (or BART/T5) model, tokenizer, retriever and training bookkeeping
    from ``hparams``, then hand everything to the parent lightning module."""
    # when loading from a pytorch lightning checkpoint, hparams are passed as dict
    if isinstance(hparams, dict):
        hparams = AttrDict(hparams)
    # Select the model class from the requested architecture.
    if hparams.model_type == "rag_sequence":
        self.model_class = RagSequenceForGeneration
    elif hparams.model_type == "rag_token":
        self.model_class = RagTokenForGeneration
    elif hparams.model_type == "bart":
        self.model_class = BartForConditionalGeneration
    else:
        # any other model type falls back to T5
        self.model_class = T5ForConditionalGeneration
    self.is_rag_model = is_rag_model(hparams.model_type)

    config_class = RagConfig if self.is_rag_model else AutoConfig
    config = config_class.from_pretrained(hparams.model_name_or_path)

    # set retriever parameters (hparams override the pretrained config when provided)
    config.index_name = hparams.index_name or config.index_name
    config.passages_path = hparams.passages_path or config.passages_path
    config.index_path = hparams.index_path or config.index_path
    config.use_dummy_dataset = hparams.use_dummy_dataset

    # set extra_model_params for generator configs and load_model
    extra_model_params = ("encoder_layerdrop", "decoder_layerdrop",
                          "attention_dropout", "dropout")
    if self.is_rag_model:
        if hparams.prefix is not None:
            config.generator.prefix = hparams.prefix
        config.label_smoothing = hparams.label_smoothing
        # move the extra params onto the generator sub-config
        hparams, config.generator = set_extra_model_params(
            extra_model_params, hparams, config.generator)
        if hparams.distributed_retriever == "ray":
            # The Ray retriever needs the handles to the retriever actors.
            retriever = RagRayDistributedRetriever.from_pretrained(
                hparams.model_name_or_path, hparams.actor_handles, config=config)
            if hparams.end2end:
                ctx_encoder_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
                    "facebook/dpr-ctx_encoder-multiset-base")
                retriever.set_ctx_encoder_tokenizer(ctx_encoder_tokenizer)
        else:
            # NOTE(review): if the retriever is not "ray", `retriever` is never
            # bound and the from_pretrained call below raises UnboundLocalError —
            # only the Ray distributed retriever is actually supported here.
            logger.info(
                "please use RAY as the distributed retrieval method")
        model = self.model_class.from_pretrained(
            hparams.model_name_or_path, config=config, retriever=retriever)
        if hparams.end2end:
            # end2end training also fine-tunes the DPR context encoder
            ctx_encoder = DPRContextEncoder.from_pretrained(
                hparams.context_encoder_name)
            model.set_context_encoder_for_training(ctx_encoder)
        prefix = config.question_encoder.prefix
    else:
        if hparams.prefix is not None:
            config.prefix = hparams.prefix
        hparams, config = set_extra_model_params(extra_model_params, hparams,
                                                 config)
        model = self.model_class.from_pretrained(
            hparams.model_name_or_path, config=config)
        prefix = config.prefix

    # RAG models need the composite RagTokenizer; everything else uses AutoTokenizer.
    tokenizer = (RagTokenizer.from_pretrained(hparams.model_name_or_path)
                 if self.is_rag_model else AutoTokenizer.from_pretrained(
                     hparams.model_name_or_path))
    self.config_dpr = DPRConfig.from_pretrained(
        hparams.context_encoder_name)
    self.custom_config = hparams
    self.context_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
        hparams.context_encoder_name)

    super().__init__(hparams, config=config, tokenizer=tokenizer, model=model)

    save_git_info(self.hparams.output_dir)
    self.output_dir = Path(self.hparams.output_dir)
    # checkpoints of the trainable context encoder go to this sub-directory
    self.dpr_ctx_check_dir = str(Path(
        self.hparams.output_dir)) + "/dpr_ctx_checkpoint"
    self.metrics_save_path = Path(self.output_dir) / "metrics.json"
    self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
    pickle_save(self.hparams, self.hparams_save_path)
    self.step_count = 0
    self.metrics = defaultdict(list)

    # kwargs shared by the train/val/test dataset constructors
    self.dataset_kwargs: dict = dict(
        data_dir=self.hparams.data_dir,
        max_source_length=self.hparams.max_source_length,
        prefix=prefix or "",
    )
    n_observations_per_split = {
        "train": self.hparams.n_train,
        "val": self.hparams.n_val,
        "test": self.hparams.n_test,
    }
    # negative counts mean "use the full split"
    self.n_obs = {
        k: v if v >= 0 else None
        for k, v in n_observations_per_split.items()
    }
    self.target_lens = {
        "train": self.hparams.max_target_length,
        "val": self.hparams.val_max_target_length,
        "test": self.hparams.test_max_target_length,
    }
    # train targets must not be longer than eval targets
    assert self.target_lens["train"] <= self.target_lens[
        "val"], f"target_lens: {self.target_lens}"
    assert self.target_lens["train"] <= self.target_lens[
        "test"], f"target_lens: {self.target_lens}"
    self.hparams.git_sha = get_git_info()["repo_sha"]
    self.num_workers = hparams.num_workers
    self.distributed_port = self.hparams.distributed_port

    # For single GPU training, init_ddp_connection is not called.
    # So we need to initialize the retrievers here.
    if hparams.gpus <= 1:
        if hparams.distributed_retriever == "ray":
            self.model.retriever.init_retrieval()
        else:
            logger.info(
                "please use RAY as the distributed retrieval method")
    self.distributed_retriever = hparams.distributed_retriever
def load(cls, pretrained_model_name_or_path, tokenizer_class=None, use_fast=False, **kwargs):
    """
    Enables loading of different Tokenizer classes with a uniform interface.
    Either infer the class from `pretrained_model_name_or_path` or define it manually via `tokenizer_class`.

    :param pretrained_model_name_or_path: The path of the saved pretrained model or its name (e.g. `bert-base-uncased`)
    :type pretrained_model_name_or_path: str
    :param tokenizer_class: (Optional) Name of the tokenizer class to load (e.g. `BertTokenizer`)
    :type tokenizer_class: str
    :param use_fast: (Optional, False by default) Indicate if FARM should try to load the fast version of the tokenizer
        (True) or use the Python one (False). Only DistilBERT, BERT and Electra fast tokenizers are supported.
    :type use_fast: bool
    :param kwargs: Passed through to the underlying ``from_pretrained`` call.
    :return: Tokenizer
    """
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)

    # guess tokenizer type from name — order matters: more specific substrings
    # (e.g. "xlm-roberta", "distilbert") are checked before generic ones.
    if tokenizer_class is None:
        if "albert" in pretrained_model_name_or_path.lower():
            tokenizer_class = "AlbertTokenizer"
        elif "xlm-roberta" in pretrained_model_name_or_path.lower():
            tokenizer_class = "XLMRobertaTokenizer"
        elif "roberta" in pretrained_model_name_or_path.lower():
            tokenizer_class = "RobertaTokenizer"
        elif 'codebert' in pretrained_model_name_or_path.lower():
            if "mlm" in pretrained_model_name_or_path.lower():
                raise NotImplementedError(
                    "MLM part of codebert is currently not supported in FARM"
                )
            else:
                # non-MLM CodeBERT uses the RoBERTa tokenizer
                tokenizer_class = "RobertaTokenizer"
        elif "camembert" in pretrained_model_name_or_path.lower(
        ) or "umberto" in pretrained_model_name_or_path:
            tokenizer_class = "CamembertTokenizer"
        elif "distilbert" in pretrained_model_name_or_path.lower():
            tokenizer_class = "DistilBertTokenizer"
        elif "bert" in pretrained_model_name_or_path.lower():
            tokenizer_class = "BertTokenizer"
        elif "xlnet" in pretrained_model_name_or_path.lower():
            tokenizer_class = "XLNetTokenizer"
        elif "electra" in pretrained_model_name_or_path.lower():
            tokenizer_class = "ElectraTokenizer"
        elif "word2vec" in pretrained_model_name_or_path.lower() or \
                "glove" in pretrained_model_name_or_path.lower() or \
                "fasttext" in pretrained_model_name_or_path.lower():
            tokenizer_class = "EmbeddingTokenizer"
        elif "minilm" in pretrained_model_name_or_path.lower():
            # MiniLM models use the BERT tokenizer
            tokenizer_class = "BertTokenizer"
        elif "dpr-question_encoder" in pretrained_model_name_or_path.lower(
        ):
            tokenizer_class = "DPRQuestionEncoderTokenizer"
        elif "dpr-ctx_encoder" in pretrained_model_name_or_path.lower():
            tokenizer_class = "DPRContextEncoderTokenizer"
        else:
            raise ValueError(
                f"Could not infer tokenizer_class from name '{pretrained_model_name_or_path}'. Set "
                f"arg `tokenizer_class` in Tokenizer.load() to one of: AlbertTokenizer, "
                f"XLMRobertaTokenizer, RobertaTokenizer, DistilBertTokenizer, BertTokenizer, or "
                f"XLNetTokenizer.")
        logger.info(f"Loading tokenizer of type '{tokenizer_class}'")

    # return appropriate tokenizer object
    ret = None
    if tokenizer_class == "AlbertTokenizer":
        if use_fast:
            # fast variant not available — fall back to the slow tokenizer
            logger.error(
                'AlbertTokenizerFast is not supported! Using AlbertTokenizer instead.'
            )
            ret = AlbertTokenizer.from_pretrained(
                pretrained_model_name_or_path, keep_accents=True, **kwargs)
        else:
            ret = AlbertTokenizer.from_pretrained(
                pretrained_model_name_or_path, keep_accents=True, **kwargs)
    elif tokenizer_class == "XLMRobertaTokenizer":
        if use_fast:
            logger.error(
                'XLMRobertaTokenizerFast is not supported! Using XLMRobertaTokenizer instead.'
            )
            ret = XLMRobertaTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        else:
            ret = XLMRobertaTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
    elif "RobertaTokenizer" in tokenizer_class:  # because it also might be a fast tokenizer we use "in"
        if use_fast:
            logger.error(
                'RobertaTokenizerFast is not supported! Using RobertaTokenizer instead.'
            )
            ret = RobertaTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        else:
            ret = RobertaTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
    elif "DistilBertTokenizer" in tokenizer_class:  # because it also might be a fast tokenizer we use "in"
        if use_fast:
            ret = DistilBertTokenizerFast.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        else:
            ret = DistilBertTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
    elif "BertTokenizer" in tokenizer_class:  # because it also might be a fast tokenizer we use "in"
        if use_fast:
            ret = BertTokenizerFast.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        else:
            ret = BertTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
    elif tokenizer_class == "XLNetTokenizer":
        if use_fast:
            logger.error(
                'XLNetTokenizerFast is not supported! Using XLNetTokenizer instead.'
            )
            ret = XLNetTokenizer.from_pretrained(
                pretrained_model_name_or_path, keep_accents=True, **kwargs)
        else:
            ret = XLNetTokenizer.from_pretrained(
                pretrained_model_name_or_path, keep_accents=True, **kwargs)
    elif "ElectraTokenizer" in tokenizer_class:  # because it also might be a fast tokenizer we use "in"
        if use_fast:
            ret = ElectraTokenizerFast.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        else:
            ret = ElectraTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
    elif tokenizer_class == "EmbeddingTokenizer":
        if use_fast:
            logger.error(
                'EmbeddingTokenizerFast is not supported! Using EmbeddingTokenizer instead.'
            )
            ret = EmbeddingTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        else:
            ret = EmbeddingTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
    elif tokenizer_class == "CamembertTokenizer":
        if use_fast:
            logger.error(
                'CamembertTokenizerFast is not supported! Using CamembertTokenizer instead.'
            )
            # NOTE(review): calls the private `_from_pretrained` — looks deliberate
            # here but verify it still works across transformers versions.
            ret = CamembertTokenizer._from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        else:
            ret = CamembertTokenizer._from_pretrained(
                pretrained_model_name_or_path, **kwargs)
    elif tokenizer_class == "DPRQuestionEncoderTokenizer" or tokenizer_class == "DPRQuestionEncoderTokenizerFast":
        if use_fast or tokenizer_class == "DPRQuestionEncoderTokenizerFast":
            ret = DPRQuestionEncoderTokenizerFast.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        else:
            ret = DPRQuestionEncoderTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
    elif tokenizer_class == "DPRContextEncoderTokenizer" or tokenizer_class == "DPRContextEncoderTokenizerFast":
        if use_fast or tokenizer_class == "DPRContextEncoderTokenizerFast":
            ret = DPRContextEncoderTokenizerFast.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        else:
            ret = DPRContextEncoderTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
    if ret is None:
        raise Exception("Unable to load tokenizer")
    else:
        return ret
def load(cls, pretrained_model_name_or_path, revision=None, tokenizer_class=None, use_fast=True, **kwargs):
    """
    Enables loading of different Tokenizer classes with a uniform interface.
    Either infer the class from the model config or define it manually via `tokenizer_class`.

    :param pretrained_model_name_or_path: The path of the saved pretrained model or its name (e.g. `bert-base-uncased`)
    :type pretrained_model_name_or_path: str
    :param revision: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
    :type revision: str
    :param tokenizer_class: (Optional) Name of the tokenizer class to load (e.g. `BertTokenizer`)
    :type tokenizer_class: str
    :param use_fast: (Optional, True by default) Indicate if FARM should try to load the fast version of the tokenizer
        (True) or use the Python one (False).
    :type use_fast: bool
    :param kwargs: Passed through to the underlying ``from_pretrained`` call.
    :return: Tokenizer
    """
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    kwargs["revision"] = revision

    # infer the tokenizer class from the model config unless given explicitly
    if tokenizer_class is None:
        tokenizer_class = cls._infer_tokenizer_class(
            pretrained_model_name_or_path)

    logger.info(f"Loading tokenizer of type '{tokenizer_class}'")
    # return appropriate tokenizer object — substring (`in`) matches also catch
    # the "...Fast" class names
    ret = None
    if "AlbertTokenizer" in tokenizer_class:
        if use_fast:
            ret = AlbertTokenizerFast.from_pretrained(
                pretrained_model_name_or_path, keep_accents=True, **kwargs)
        else:
            ret = AlbertTokenizer.from_pretrained(
                pretrained_model_name_or_path, keep_accents=True, **kwargs)
    elif "XLMRobertaTokenizer" in tokenizer_class:
        if use_fast:
            ret = XLMRobertaTokenizerFast.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        else:
            ret = XLMRobertaTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
    elif "RobertaTokenizer" in tokenizer_class:
        if use_fast:
            ret = RobertaTokenizerFast.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        else:
            ret = RobertaTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
    elif "DistilBertTokenizer" in tokenizer_class:
        if use_fast:
            ret = DistilBertTokenizerFast.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        else:
            ret = DistilBertTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
    elif "BertTokenizer" in tokenizer_class:
        if use_fast:
            ret = BertTokenizerFast.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        else:
            ret = BertTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
    elif "XLNetTokenizer" in tokenizer_class:
        if use_fast:
            ret = XLNetTokenizerFast.from_pretrained(
                pretrained_model_name_or_path, keep_accents=True, **kwargs)
        else:
            ret = XLNetTokenizer.from_pretrained(
                pretrained_model_name_or_path, keep_accents=True, **kwargs)
    elif "ElectraTokenizer" in tokenizer_class:
        if use_fast:
            ret = ElectraTokenizerFast.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        else:
            ret = ElectraTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
    elif tokenizer_class == "EmbeddingTokenizer":
        if use_fast:
            # no fast variant exists — fall back to the slow tokenizer
            logger.error(
                'EmbeddingTokenizerFast is not supported! Using EmbeddingTokenizer instead.'
            )
            ret = EmbeddingTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        else:
            ret = EmbeddingTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
    elif "CamembertTokenizer" in tokenizer_class:
        if use_fast:
            ret = CamembertTokenizerFast.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        else:
            ret = CamembertTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
    elif "DPRQuestionEncoderTokenizer" in tokenizer_class:
        if use_fast:
            ret = DPRQuestionEncoderTokenizerFast.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        else:
            ret = DPRQuestionEncoderTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
    elif "DPRContextEncoderTokenizer" in tokenizer_class:
        if use_fast:
            ret = DPRContextEncoderTokenizerFast.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        else:
            ret = DPRContextEncoderTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
    if ret is None:
        raise Exception("Unable to load tokenizer")
    else:
        return ret