def main(
    rag_example_args: "RagExampleArguments",
    processing_args: "ProcessingArguments",
    index_hnsw_args: "IndexHnswArguments",
):
    """Build a knowledge dataset for RAG from a csv file and index it with Faiss.

    The embedded passages are written to ``<output_dir>/my_knowledge_dataset`` and the
    HNSW index to ``<output_dir>/my_knowledge_dataset_hnsw_index.faiss``.

    Args:
        rag_example_args: provides ``csv_path``, ``dpr_ctx_encoder_model_name`` and ``output_dir``.
        processing_args: provides ``num_proc`` (used when splitting documents) and
            ``batch_size`` (used when computing embeddings).
        index_hnsw_args: provides ``d`` and ``m``, forwarded to ``faiss.IndexHNSWFlat``.
    """

    ######################################
    logger.info("Step 1 - Create the dataset")
    ######################################

    # RAG needs a dataset with three columns:
    # - title (string): title of the document
    # - text (string): text of a passage of the document
    # - embeddings (array of dimension d): DPR representation of the passage
    # The input is expected to be a tab-separated csv file holding "title" and "text".
    csv_file = rag_example_args.csv_path
    assert os.path.isfile(csv_file), "Please provide a valid path to a csv file"

    # Load the csv file as a Dataset object.
    # More info about loading csv files in the documentation: https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files
    raw_documents = load_dataset(
        "csv",
        data_files=[csv_file],
        split="train",
        delimiter="\t",
        column_names=["title", "text"],
    )

    # Break the documents into passages (split_documents is defined elsewhere in the file)
    passages = raw_documents.map(split_documents, batched=True, num_proc=processing_args.num_proc)

    # Compute the embeddings of every passage with the DPR context encoder
    encoder_name = rag_example_args.dpr_ctx_encoder_model_name
    ctx_encoder = DPRContextEncoder.from_pretrained(encoder_name).to(device=device)
    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(encoder_name)
    embed_batch = partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer)
    # optional, save as float32 instead of float64 to save space
    embedded_schema = Features(
        {"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
    )
    dataset = passages.map(
        embed_batch,
        batched=True,
        batch_size=processing_args.batch_size,
        features=embedded_schema,
    )

    # Persist the dataset; it can be reloaded later with
    # datasets.load_from_disk(passages_path)
    passages_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset")
    dataset.save_to_disk(passages_path)

    ######################################
    logger.info("Step 2 - Index the dataset")
    ######################################

    # Faiss implementation of HNSW for fast approximate nearest neighbor search
    hnsw_index = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m, faiss.METRIC_INNER_PRODUCT)
    dataset.add_faiss_index("embeddings", custom_index=hnsw_index)

    # Persist the index next to the dataset
    index_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset_hnsw_index.faiss")
    embeddings_index = dataset.get_index("embeddings")
    embeddings_index.save(index_path)
Example #2
0
 def __init__(self):
     """Load the DPR context and question encoders (moved to ``Config.device``) and their fast tokenizers."""
     ctx_checkpoint = "facebook/dpr-ctx_encoder-multiset-base"
     question_checkpoint = "facebook/dpr-question_encoder-multiset-base"

     # Encoders are placed on the configured device right after loading.
     ctx_model = DPRContextEncoder.from_pretrained(ctx_checkpoint)
     self.ctx_encoder = ctx_model.to(Config.device)
     question_model = DPRQuestionEncoder.from_pretrained(question_checkpoint)
     self.q_encoder = question_model.to(Config.device)

     # Each tokenizer is loaded from the same checkpoint as its encoder.
     self.ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(ctx_checkpoint)
     self.q_tokenizer = DPRQuestionEncoderTokenizerFast.from_pretrained(question_checkpoint)
Example #3
0
def generate_faiss_index_dataset(data, ctx_encoder_name, args, device):
    """
    Build a knowledge dataset with DPR embeddings and attach a Faiss HNSW index to it.

    Adapted from Huggingface example script at https://github.com/huggingface/transformers/blob/master/examples/research_projects/rag/use_own_knowledge_dataset.py

    Args:
        data: either a path (str) to a tab-separated file with "title" and "text"
            columns, or an object accepted by ``HFDataset.from_pandas``
            (presumably a pandas DataFrame - confirm against callers).
        ctx_encoder_name: name or path used to load both the DPR context encoder
            and its fast tokenizer.
        args: settings object; this function reads ``split_text_n``,
            ``split_text_character``, ``process_count``, ``rag_embed_batch_size``,
            ``save_knowledge_dataset``, ``output_dir``, ``faiss_d`` and ``faiss_m``.
        device: device the encoder is moved to; also forwarded to ``embed``.

    Returns:
        The embedded dataset with a Faiss index registered on its "embeddings" column.
    """
    import faiss

    loaded_from_file = isinstance(data, str)
    if loaded_from_file:
        # No `split` is requested here, so the result is keyed by split name;
        # the "train" entry is selected after the embeddings are computed.
        dataset = load_dataset(
            "csv",
            data_files=data,
            delimiter="\t",
            column_names=["title", "text"],
        )
    else:
        dataset = HFDataset.from_pandas(data)

    # Split the documents into passages.
    splitter = partial(
        split_documents,
        split_text_n=args.split_text_n,
        split_text_character=args.split_text_character,
    )
    dataset = dataset.map(splitter, batched=True, num_proc=args.process_count)

    # Load the encoder and tokenizer used to embed the passages.
    ctx_encoder = DPRContextEncoder.from_pretrained(ctx_encoder_name)
    ctx_encoder = ctx_encoder.to(device=device)
    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(ctx_encoder_name)

    # optional, save as float32 instead of float64 to save space
    embedded_schema = Features({
        "text": Value("string"),
        "title": Value("string"),
        "embeddings": Sequence(Value("float32")),
    })
    embedder = partial(
        embed,
        ctx_encoder=ctx_encoder,
        ctx_tokenizer=ctx_tokenizer,
        device=device,
    )
    dataset = dataset.map(
        embedder,
        batched=True,
        batch_size=args.rag_embed_batch_size,
        features=embedded_schema,
    )
    if loaded_from_file:
        dataset = dataset["train"]

    # Optionally persist the embedded passages (before the index is attached).
    if args.save_knowledge_dataset:
        output_dataset_directory = os.path.join(args.output_dir, "knowledge_dataset")
        os.makedirs(output_dataset_directory, exist_ok=True)
        dataset.save_to_disk(output_dataset_directory)

    hnsw_index = faiss.IndexHNSWFlat(args.faiss_d, args.faiss_m, faiss.METRIC_INNER_PRODUCT)
    dataset.add_faiss_index("embeddings", custom_index=hnsw_index)

    return dataset
Example #4
0
def embed_update(ctx_encoder, total_processes, device, process_num, shard_dir,
                 csv_path):
    """Embed one shard of the knowledge base and save it to disk.

    The csv file is loaded and split into passages, the passages are divided
    into ``total_processes`` contiguous shards, and the shard selected by
    ``process_num`` is embedded and written to ``<shard_dir>/data_<process_num>``.

    Args:
        ctx_encoder: DPR context encoder; it is moved to ``device`` before use.
        total_processes: number of shards the dataset is divided into.
        device: device on which the embeddings are computed.
        process_num: index of the shard handled by this call. It follows list
            indexing rules relative to ``total_processes`` (negative values
            count from the end, out-of-range values raise ``IndexError``).
        shard_dir: directory under which the shard folder is created.
        csv_path: tab-separated file with "title" and "text" columns.
    """

    kb_dataset = load_dataset("csv",
                              data_files=[csv_path],
                              split="train",
                              delimiter="\t",
                              column_names=["title", "text"])
    kb_dataset = kb_dataset.map(
        split_documents, batched=True,
        num_proc=1)  # if you want you can load already splitted csv.

    # Only one shard is used by this call, so build just that one instead of
    # materialising every shard in a list. Indexing a range keeps the exact
    # index semantics the previous list lookup had.
    shard_index = range(total_processes)[process_num]
    data_shard = kb_dataset.shard(total_processes, shard_index, contiguous=True)

    arrow_folder = "data_" + str(process_num)
    passages_path = os.path.join(shard_dir, arrow_folder)

    context_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
        "facebook/dpr-ctx_encoder-multiset-base")
    ctx_encoder = ctx_encoder.to(device=device)

    def embed(documents: dict, ctx_encoder: DPRContextEncoder,
              ctx_tokenizer: DPRContextEncoderTokenizerFast, device) -> dict:
        """Compute the DPR embeddings of document passages"""
        input_ids = ctx_tokenizer(documents["title"],
                                  documents["text"],
                                  truncation=True,
                                  padding="longest",
                                  return_tensors="pt")["input_ids"]
        embeddings = ctx_encoder(input_ids.to(device=device),
                                 return_dict=True).pooler_output
        # Detach and move to CPU so the result can be stored as a numpy array.
        return {"embeddings": embeddings.detach().cpu().numpy()}

    new_features = Features({
        "text": Value("string"),
        "title": Value("string"),
        "embeddings": Sequence(Value("float32"))
    })  # optional, save as float32 instead of float64 to save space

    dataset = data_shard.map(
        partial(embed,
                ctx_encoder=ctx_encoder,
                ctx_tokenizer=context_tokenizer,
                device=device),
        batched=True,
        batch_size=16,
        features=new_features,
    )
    dataset.save_to_disk(passages_path)
Example #5
0
    def load_dataset(self) -> None:
        """Build, embed, save and index the RAG knowledge dataset.

        Reads the comma-separated file at ``self.csv_path``, runs
        ``split_documents`` over it, embeds the result with the DPR context
        encoder named by ``self.context_encoder`` on ``self.device``, saves the
        dataset to ``self.dataset_path``, then adds a Faiss HNSW index on the
        "embeddings" column and saves it to ``self.faiss_path``.
        """
        logger.debug('loading rag dataset: %s', self.name)

        # The call below resolves to the module-level `load_dataset`, not this method.
        self.dataset = load_dataset(
            'csv',
            data_files=[self.csv_path],
            split='train',
            delimiter=',',
            column_names=['title', 'text'],
        )

        # NOTE(review): batch_size is passed together with batched=False -
        # confirm this is the intended way to call split_documents.
        self.dataset = self.dataset.map(
            split_documents,
            batched=False,
            num_proc=6,
            batch_size=100,
        )

        encoder_name = self.context_encoder
        ctx_encoder = DPRContextEncoder.from_pretrained(encoder_name)
        ctx_encoder = ctx_encoder.to(device=self.device)
        encoder_name = self.context_encoder
        ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(encoder_name)

        # optional, save as float32 instead of float64 to save space
        embedded_schema = Features({
            'text': Value('string'),
            'title': Value('string'),
            'embeddings': Sequence(Value('float32')),
        })
        embedder = partial(
            embed,
            ctx_encoder=ctx_encoder,
            ctx_tokenizer=ctx_tokenizer,
            device=self.device,
        )
        self.dataset = self.dataset.map(
            embedder,
            batched=True,
            batch_size=16,
            features=embedded_schema,
        )

        # The dataset is written to disk before the index is attached.
        self.dataset.save_to_disk(self.dataset_path)

        hnsw_index = faiss.IndexHNSWFlat(768, 128, faiss.METRIC_INNER_PRODUCT)
        self.dataset.add_faiss_index('embeddings', custom_index=hnsw_index)

        embeddings_index = self.dataset.get_index('embeddings')
        embeddings_index.save(self.faiss_path)
Example #6
0
    def test_embed(self):
        """The first embedding returned by embed() should have 768 components."""
        checkpoint = 'facebook/dpr-ctx_encoder-multiset-base'

        ctx_encoder = DPRContextEncoder.from_pretrained(checkpoint).to(device='cpu')
        ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(checkpoint)

        document = {'title': 'something', 'text': 'blah'}
        result = embed(document, ctx_encoder, ctx_tokenizer, 'cpu')
        first_embedding = result['embeddings'][0]

        self.assertEqual(len(first_embedding), 768)
Example #7
0
def main(
    rag_example_args: "RagExampleArguments",
    processing_args: "ProcessingArguments",
    index_hnsw_args: "IndexHnswArguments",
):
    """Create a knowledge dataset, index it, load RAG on top of it and answer one question.

    Args:
        rag_example_args: provides ``csv_path``, ``dpr_ctx_encoder_model_name``,
            ``output_dir``, ``rag_model_name`` and ``question``.
        processing_args: provides ``num_proc`` (used when splitting documents) and
            ``batch_size`` (used when computing embeddings).
        index_hnsw_args: provides ``d`` and ``m``, forwarded to ``faiss.IndexHNSWFlat``.
    """

    ######################################
    logger.info("Step 1 - Create the dataset")
    ######################################

    # RAG needs a dataset with three columns:
    # - title (string): title of the document
    # - text (string): text of a passage of the document
    # - embeddings (array of dimension d): DPR representation of the passage

    # The input is expected to be a tab-separated csv file holding "title" and "text".
    csv_file = rag_example_args.csv_path
    assert os.path.isfile(csv_file), "Please provide a valid path to a csv file"

    # Load the csv file as a Dataset object.
    # More info about loading csv files in the documentation: https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files
    raw_documents = load_dataset(
        "csv",
        data_files=[csv_file],
        split="train",
        delimiter="\t",
        column_names=["title", "text"],
    )

    # Break the documents into passages (split_documents is defined elsewhere in the file)
    passages = raw_documents.map(
        split_documents,
        batched=True,
        num_proc=processing_args.num_proc,
    )

    # Compute the embeddings of every passage with the DPR context encoder
    encoder_name = rag_example_args.dpr_ctx_encoder_model_name
    ctx_encoder = DPRContextEncoder.from_pretrained(encoder_name).to(device=device)
    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(encoder_name)
    embed_batch = partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer)
    dataset = passages.map(
        embed_batch,
        batched=True,
        batch_size=processing_args.batch_size,
    )

    # Persist the dataset; it can be reloaded later with
    # datasets.load_from_disk(passages_path)
    passages_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset")
    dataset.save_to_disk(passages_path)

    ######################################
    logger.info("Step 2 - Index the dataset")
    ######################################

    # Faiss implementation of HNSW for fast approximate nearest neighbor search
    hnsw_index = faiss.IndexHNSWFlat(
        index_hnsw_args.d,
        index_hnsw_args.m,
        faiss.METRIC_INNER_PRODUCT,
    )
    dataset.add_faiss_index("embeddings", custom_index=hnsw_index)

    # Persist the index; it can be reloaded later with
    # dataset.load_faiss_index("embeddings", index_path)
    index_path = os.path.join(
        rag_example_args.output_dir,
        "my_knowledge_dataset_hnsw_index.faiss",
    )
    embeddings_index = dataset.get_index("embeddings")
    embeddings_index.save(index_path)

    ######################################
    logger.info("Step 3 - Load RAG")
    ######################################

    # Easy way to load the model: hand the in-memory indexed dataset to the retriever.
    # For distributed fine-tuning you'll need to provide the paths instead, as the dataset and the index are loaded separately:
    # RagRetriever.from_pretrained(rag_model_name, index_name="custom", passages_path=passages_path, index_path=index_path)
    retriever = RagRetriever.from_pretrained(
        rag_example_args.rag_model_name,
        index_name="custom",
        indexed_dataset=dataset,
    )
    model = RagSequenceForGeneration.from_pretrained(
        rag_example_args.rag_model_name,
        retriever=retriever,
    )
    tokenizer = RagTokenizer.from_pretrained(rag_example_args.rag_model_name)

    ######################################
    logger.info("Step 4 - Have fun")
    ######################################

    # Fall back to a default question when none was supplied.
    question = rag_example_args.question or "What does Moses' rod turn into ?"
    encoded_question = tokenizer.question_encoder(question, return_tensors="pt")
    input_ids = encoded_question["input_ids"]
    generated = model.generate(input_ids)
    decoded_answers = tokenizer.batch_decode(generated, skip_special_tokens=True)
    generated_string = decoded_answers[0]
    logger.info("Q: " + question)
    logger.info("A: " + generated_string)
Example #8
0
    def __init__(self, hparams, **kwargs):
        """Set up the model, tokenizer, retriever and bookkeeping for fine-tuning.

        Depending on ``hparams.model_type`` a RAG model ("rag_sequence" /
        "rag_token"), BART ("bart") or T5 (any other value) is loaded from
        ``hparams.model_name_or_path``.

        Args:
            hparams: hyper-parameters as an attribute-style object, or as a plain
                dict (e.g. when loading from a pytorch lightning checkpoint), in
                which case it is wrapped in ``AttrDict``.
            **kwargs: accepted but not used by this constructor.

        Raises:
            ValueError: if a RAG model is requested with a
                ``hparams.distributed_retriever`` other than "ray"; no retriever
                can be built in that case.
            AssertionError: if the train max target length exceeds the val or
                test max target length.
        """
        # when loading from a pytorch lightning checkpoint, hparams are passed as dict
        if isinstance(hparams, dict):
            hparams = AttrDict(hparams)
        if hparams.model_type == "rag_sequence":
            self.model_class = RagSequenceForGeneration
        elif hparams.model_type == "rag_token":
            self.model_class = RagTokenForGeneration
        elif hparams.model_type == "bart":
            self.model_class = BartForConditionalGeneration
        else:
            self.model_class = T5ForConditionalGeneration
        self.is_rag_model = is_rag_model(hparams.model_type)

        config_class = RagConfig if self.is_rag_model else AutoConfig
        config = config_class.from_pretrained(hparams.model_name_or_path)

        # set retriever parameters (hparams take precedence over the stored config when set)
        config.index_name = hparams.index_name or config.index_name
        config.passages_path = hparams.passages_path or config.passages_path
        config.index_path = hparams.index_path or config.index_path
        config.use_dummy_dataset = hparams.use_dummy_dataset

        # set extra_model_params for generator configs and load_model
        extra_model_params = ("encoder_layerdrop", "decoder_layerdrop",
                              "attention_dropout", "dropout")
        if self.is_rag_model:
            if hparams.prefix is not None:
                config.generator.prefix = hparams.prefix
            config.label_smoothing = hparams.label_smoothing
            hparams, config.generator = set_extra_model_params(
                extra_model_params, hparams, config.generator)
            if hparams.distributed_retriever == "ray":
                # The Ray retriever needs the handles to the retriever actors.
                retriever = RagRayDistributedRetriever.from_pretrained(
                    hparams.model_name_or_path,
                    hparams.actor_handles,
                    config=config)

                if hparams.end2end:
                    ctx_encoder_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
                        "facebook/dpr-ctx_encoder-multiset-base")
                    retriever.set_ctx_encoder_tokenizer(ctx_encoder_tokenizer)
            else:
                logger.info(
                    "please use RAY as the distributed retrieval method")
                # Without a retriever the model below cannot be built; fail with a
                # clear message instead of an UnboundLocalError on `retriever`.
                raise ValueError(
                    f"Unsupported distributed_retriever "
                    f"{hparams.distributed_retriever!r}: only 'ray' is supported "
                    f"for RAG models.")

            model = self.model_class.from_pretrained(
                hparams.model_name_or_path, config=config, retriever=retriever)
            if hparams.end2end:
                ctx_encoder = DPRContextEncoder.from_pretrained(
                    hparams.context_encoder_name)
                model.set_context_encoder_for_training(ctx_encoder)
            prefix = config.question_encoder.prefix
        else:
            if hparams.prefix is not None:
                config.prefix = hparams.prefix
            hparams, config = set_extra_model_params(extra_model_params,
                                                     hparams, config)
            model = self.model_class.from_pretrained(
                hparams.model_name_or_path, config=config)
            prefix = config.prefix

        tokenizer = (RagTokenizer.from_pretrained(hparams.model_name_or_path)
                     if self.is_rag_model else AutoTokenizer.from_pretrained(
                         hparams.model_name_or_path))

        # DPR context-encoder config and tokenizer are loaded for every model type.
        self.config_dpr = DPRConfig.from_pretrained(
            hparams.context_encoder_name)
        self.custom_config = hparams
        self.context_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
            hparams.context_encoder_name)

        super().__init__(hparams,
                         config=config,
                         tokenizer=tokenizer,
                         model=model)

        # Output locations and run bookkeeping.
        save_git_info(self.hparams.output_dir)
        self.output_dir = Path(self.hparams.output_dir)
        self.dpr_ctx_check_dir = str(Path(
            self.hparams.output_dir)) + "/dpr_ctx_checkpoint"
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        self.metrics = defaultdict(list)

        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=prefix or "",
        )
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        # Negative counts are mapped to None.
        self.n_obs = {
            k: v if v >= 0 else None
            for k, v in n_observations_per_split.items()
        }
        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens[
            "val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens[
            "test"], f"target_lens: {self.target_lens}"

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.distributed_port = self.hparams.distributed_port

        # For single GPU training, init_ddp_connection is not called.
        # So we need to initialize the retrievers here.
        if hparams.gpus <= 1:
            if hparams.distributed_retriever == "ray":
                self.model.retriever.init_retrieval()
            else:
                logger.info(
                    "please use RAY as the distributed retrieval method")

        self.distributed_retriever = hparams.distributed_retriever
Example #9
0
    def load(cls,
             pretrained_model_name_or_path,
             tokenizer_class=None,
             use_fast=False,
             **kwargs):
        """
        Enables loading of different Tokenizer classes with a uniform interface. Either infer the class from
        `pretrained_model_name_or_path` or define it manually via `tokenizer_class`.

        :param pretrained_model_name_or_path:  The path of the saved pretrained model or its name (e.g. `bert-base-uncased`)
        :type pretrained_model_name_or_path: str
        :param tokenizer_class: (Optional) Name of the tokenizer class to load (e.g. `BertTokenizer`)
        :type tokenizer_class: str
        :param use_fast: (Optional, False by default) Indicate if FARM should try to load the fast version of the tokenizer (True) or
            use the Python one (False).
            Fast versions are loaded for DistilBERT, BERT, Electra and the two DPR tokenizers; for every other
            class an error is logged and the Python tokenizer is used instead.
        :type use_fast: bool
        :param kwargs: Forwarded to the `from_pretrained` call of the selected tokenizer class.
        :return: Tokenizer
        :raises NotImplementedError: if the name refers to the MLM variant of codebert.
        :raises ValueError: if `tokenizer_class` is not given and cannot be inferred from the name.
        :raises Exception: if `tokenizer_class` does not match any supported tokenizer.
        """

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        # guess tokenizer type from name
        if tokenizer_class is None:
            # Lower-case once so that every check below is case-insensitive.
            name_lower = pretrained_model_name_or_path.lower()
            # Order matters: more specific names must be tested before names
            # they contain (e.g. "distilbert" before "bert").
            if "albert" in name_lower:
                tokenizer_class = "AlbertTokenizer"
            elif "xlm-roberta" in name_lower:
                tokenizer_class = "XLMRobertaTokenizer"
            elif "roberta" in name_lower:
                tokenizer_class = "RobertaTokenizer"
            elif 'codebert' in name_lower:
                if "mlm" in name_lower:
                    raise NotImplementedError(
                        "MLM part of codebert is currently not supported in FARM"
                    )
                else:
                    tokenizer_class = "RobertaTokenizer"
            elif "camembert" in name_lower or "umberto" in name_lower:
                tokenizer_class = "CamembertTokenizer"
            elif "distilbert" in name_lower:
                tokenizer_class = "DistilBertTokenizer"
            elif "bert" in name_lower:
                tokenizer_class = "BertTokenizer"
            elif "xlnet" in name_lower:
                tokenizer_class = "XLNetTokenizer"
            elif "electra" in name_lower:
                tokenizer_class = "ElectraTokenizer"
            elif "word2vec" in name_lower or \
                    "glove" in name_lower or \
                    "fasttext" in name_lower:
                tokenizer_class = "EmbeddingTokenizer"
            elif "minilm" in name_lower:
                tokenizer_class = "BertTokenizer"
            elif "dpr-question_encoder" in name_lower:
                tokenizer_class = "DPRQuestionEncoderTokenizer"
            elif "dpr-ctx_encoder" in name_lower:
                tokenizer_class = "DPRContextEncoderTokenizer"
            else:
                raise ValueError(
                    f"Could not infer tokenizer_class from name '{pretrained_model_name_or_path}'. Set "
                    f"arg `tokenizer_class` in Tokenizer.load() to one of: AlbertTokenizer, "
                    f"XLMRobertaTokenizer, RobertaTokenizer, DistilBertTokenizer, BertTokenizer, or "
                    f"XLNetTokenizer.")
            logger.info(f"Loading tokenizer of type '{tokenizer_class}'")
        # return appropriate tokenizer object
        ret = None
        if tokenizer_class == "AlbertTokenizer":
            if use_fast:
                logger.error(
                    'AlbertTokenizerFast is not supported! Using AlbertTokenizer instead.'
                )
            ret = AlbertTokenizer.from_pretrained(
                pretrained_model_name_or_path, keep_accents=True, **kwargs)
        elif tokenizer_class == "XLMRobertaTokenizer":
            if use_fast:
                logger.error(
                    'XLMRobertaTokenizerFast is not supported! Using XLMRobertaTokenizer instead.'
                )
            ret = XLMRobertaTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        elif "RobertaTokenizer" in tokenizer_class:  # because it also might be a fast tokenizer we use "in"
            if use_fast:
                logger.error(
                    'RobertaTokenizerFast is not supported! Using RobertaTokenizer instead.'
                )
            ret = RobertaTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        elif "DistilBertTokenizer" in tokenizer_class:  # because it also might be a fast tokenizer we use "in"
            if use_fast:
                ret = DistilBertTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = DistilBertTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif "BertTokenizer" in tokenizer_class:  # because it also might be a fast tokenizer we use "in"
            if use_fast:
                ret = BertTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = BertTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif tokenizer_class == "XLNetTokenizer":
            if use_fast:
                logger.error(
                    'XLNetTokenizerFast is not supported! Using XLNetTokenizer instead.'
                )
            ret = XLNetTokenizer.from_pretrained(
                pretrained_model_name_or_path, keep_accents=True, **kwargs)
        elif "ElectraTokenizer" in tokenizer_class:  # because it also might be a fast tokenizer we use "in"
            if use_fast:
                ret = ElectraTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = ElectraTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif tokenizer_class == "EmbeddingTokenizer":
            if use_fast:
                logger.error(
                    'EmbeddingTokenizerFast is not supported! Using EmbeddingTokenizer instead.'
                )
            ret = EmbeddingTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        elif tokenizer_class == "CamembertTokenizer":
            if use_fast:
                logger.error(
                    'CamembertTokenizerFast is not supported! Using CamembertTokenizer instead.'
                )
            # Use the public constructor like every other branch rather than
            # the private `_from_pretrained` hook.
            ret = CamembertTokenizer.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        elif tokenizer_class == "DPRQuestionEncoderTokenizer" or tokenizer_class == "DPRQuestionEncoderTokenizerFast":
            # An explicit "...Fast" class name selects the fast tokenizer even when use_fast is False.
            if use_fast or tokenizer_class == "DPRQuestionEncoderTokenizerFast":
                ret = DPRQuestionEncoderTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = DPRQuestionEncoderTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif tokenizer_class == "DPRContextEncoderTokenizer" or tokenizer_class == "DPRContextEncoderTokenizerFast":
            # An explicit "...Fast" class name selects the fast tokenizer even when use_fast is False.
            if use_fast or tokenizer_class == "DPRContextEncoderTokenizerFast":
                ret = DPRContextEncoderTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = DPRContextEncoderTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        if ret is None:
            raise Exception("Unable to load tokenizer")
        else:
            return ret
Example #10
0
    def load(cls,
             pretrained_model_name_or_path,
             revision=None,
             tokenizer_class=None,
             use_fast=True,
             **kwargs):
        """
        Load a tokenizer through one uniform entry point. The tokenizer class is either inferred via
        `cls._infer_tokenizer_class` or given explicitly via `tokenizer_class`.

        :param pretrained_model_name_or_path:  The path of the saved pretrained model or its name (e.g. `bert-base-uncased`)
        :type pretrained_model_name_or_path: str
        :param revision: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
            It is forwarded to `from_pretrained` as the `revision` keyword argument.
        :type revision: str
        :param tokenizer_class: (Optional) Name of the tokenizer class to load (e.g. `BertTokenizer`)
        :type tokenizer_class: str
        :param use_fast: (Optional, True by default) Load the fast version of the tokenizer (True) or
            the Python one (False). EmbeddingTokenizer has no fast version: an error is logged and the
            Python tokenizer is used.
        :type use_fast: bool
        :param kwargs: Forwarded to the `from_pretrained` call of the selected tokenizer class.
        :return: Tokenizer
        """
        model_id = str(pretrained_model_name_or_path)
        kwargs["revision"] = revision

        if tokenizer_class is None:
            tokenizer_class = cls._infer_tokenizer_class(model_id)

        logger.info(f"Loading tokenizer of type '{tokenizer_class}'")

        # Substring matching lets both "XTokenizer" and "XTokenizerFast" select a family,
        # so longer names must be tested before the names they contain.
        ret = None
        if "AlbertTokenizer" in tokenizer_class:
            picked = AlbertTokenizerFast if use_fast else AlbertTokenizer
            ret = picked.from_pretrained(model_id, keep_accents=True, **kwargs)
        elif "XLMRobertaTokenizer" in tokenizer_class:
            picked = XLMRobertaTokenizerFast if use_fast else XLMRobertaTokenizer
            ret = picked.from_pretrained(model_id, **kwargs)
        elif "RobertaTokenizer" in tokenizer_class:
            picked = RobertaTokenizerFast if use_fast else RobertaTokenizer
            ret = picked.from_pretrained(model_id, **kwargs)
        elif "DistilBertTokenizer" in tokenizer_class:
            picked = DistilBertTokenizerFast if use_fast else DistilBertTokenizer
            ret = picked.from_pretrained(model_id, **kwargs)
        elif "BertTokenizer" in tokenizer_class:
            picked = BertTokenizerFast if use_fast else BertTokenizer
            ret = picked.from_pretrained(model_id, **kwargs)
        elif "XLNetTokenizer" in tokenizer_class:
            picked = XLNetTokenizerFast if use_fast else XLNetTokenizer
            ret = picked.from_pretrained(model_id, keep_accents=True, **kwargs)
        elif "ElectraTokenizer" in tokenizer_class:
            picked = ElectraTokenizerFast if use_fast else ElectraTokenizer
            ret = picked.from_pretrained(model_id, **kwargs)
        elif tokenizer_class == "EmbeddingTokenizer":
            # No fast variant exists for this class; only report it and load the regular one.
            if use_fast:
                logger.error(
                    'EmbeddingTokenizerFast is not supported! Using EmbeddingTokenizer instead.'
                )
            ret = EmbeddingTokenizer.from_pretrained(model_id, **kwargs)
        elif "CamembertTokenizer" in tokenizer_class:
            picked = CamembertTokenizerFast if use_fast else CamembertTokenizer
            ret = picked.from_pretrained(model_id, **kwargs)
        elif "DPRQuestionEncoderTokenizer" in tokenizer_class:
            picked = DPRQuestionEncoderTokenizerFast if use_fast else DPRQuestionEncoderTokenizer
            ret = picked.from_pretrained(model_id, **kwargs)
        elif "DPRContextEncoderTokenizer" in tokenizer_class:
            picked = DPRContextEncoderTokenizerFast if use_fast else DPRContextEncoderTokenizer
            ret = picked.from_pretrained(model_id, **kwargs)

        if ret is None:
            raise Exception("Unable to load tokenizer")
        return ret