Example #1
    def run(self):
        if self._model_type == "bert":
            try:
                from transformers.convert_bert_original_tf_checkpoint_to_pytorch import (
                    convert_tf_checkpoint_to_pytorch,
                )
            except ImportError:
                msg = (
                    "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
                    "In that case, it requires TensorFlow to be installed. Please see "
                    "https://www.tensorflow.org/install/ for installation instructions."
                )
                raise ImportError(msg)

            convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        elif self._model_type == "gpt":
            from transformers.convert_openai_original_tf_checkpoint_to_pytorch import (
                convert_openai_checkpoint_to_pytorch,
            )

            convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        elif self._model_type == "transfo_xl":
            try:
                from transformers.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
                    convert_transfo_xl_checkpoint_to_pytorch,
                )
            except ImportError:
                msg = (
                    "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
                    "In that case, it requires TensorFlow to be installed. Please see "
                    "https://www.tensorflow.org/install/ for installation instructions."
                )
                raise ImportError(msg)

            if "ckpt" in self._tf_checkpoint.lower():
                TF_CHECKPOINT = self._tf_checkpoint
                TF_DATASET_FILE = ""
            else:
                TF_DATASET_FILE = self._tf_checkpoint
                TF_CHECKPOINT = ""
            convert_transfo_xl_checkpoint_to_pytorch(
                TF_CHECKPOINT, self._config, self._pytorch_dump_output, TF_DATASET_FILE
            )
        elif self._model_type == "gpt2":
            try:
                from transformers.convert_gpt2_original_tf_checkpoint_to_pytorch import (
                    convert_gpt2_checkpoint_to_pytorch,
                )
            except ImportError:
                msg = (
                    "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
                    "In that case, it requires TensorFlow to be installed. Please see "
                    "https://www.tensorflow.org/install/ for installation instructions."
                )
                raise ImportError(msg)

            convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        elif self._model_type == "xlnet":
            try:
                from transformers.convert_xlnet_original_tf_checkpoint_to_pytorch import (
                    convert_xlnet_checkpoint_to_pytorch,
                )
            except ImportError:
                msg = (
                    "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
                    "In that case, it requires TensorFlow to be installed. Please see "
                    "https://www.tensorflow.org/install/ for installation instructions."
                )
                raise ImportError(msg)

            convert_xlnet_checkpoint_to_pytorch(
                self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name
            )
        elif self._model_type == "xlm":
            from transformers.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
                convert_xlm_checkpoint_to_pytorch,
            )

            convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
        elif self._model_type == "t5":
            from transformers.convert_t5_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch

            convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        else:
            msg = "--model_type should be selected in the list " "[bert, gpt, gpt2, transfo_xl, xlnet, xlm, t5]"
            raise ValueError(msg)
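
Each branch lazily imports its converter so that TensorFlow is only required when a conversion is actually requested. The try/except blocks above repeat the same pattern; a minimal sketch (not part of transformers) of how that pattern could be factored into a helper, with the error text taken from the messages above:

from importlib import import_module

# Hypothetical helper, not part of the transformers API.
TF_IMPORT_ERROR = (
    "transformers can only be used from the command line to convert TensorFlow "
    "models to PyTorch. In that case, it requires TensorFlow to be installed. "
    "Please see https://www.tensorflow.org/install/ for installation instructions."
)

def load_converter(module_name: str, func_name: str):
    """Import func_name from module_name, raising a helpful error if TensorFlow is missing."""
    try:
        module = import_module(module_name)
    except ImportError:
        raise ImportError(TF_IMPORT_ERROR)
    return getattr(module, func_name)

With such a helper, each elif branch reduces to a single load_converter(...) call followed by the conversion itself.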
Example #2
    def __init__(
        self,
        model: str = None,
        config: Union[str, GPT2Config] = None,
        vocab_file: str = None,
        merges_file: str = None,
        cache_dir: str = "aitextgen",
        tf_gpt2: str = None,
        to_gpu: bool = False,
        to_fp16: bool = False,
        verbose: bool = False,
        torchscript: bool = False,
        ts_to_trace: bool = False,
        bos_token: str = None,
        eos_token: str = None,
        unk_token: str = None,
        **kwargs,
    ) -> None:

        if not verbose:
            for module in [
                    "transformers.file_utils",
                    "transformers.configuration_utils",
                    "transformers.tokenization_utils",
                    "filelock",
                    "transformers.modeling_gpt2",
            ]:
                logging.getLogger(module).setLevel(logging.WARN)
            logging.getLogger("transformers.modeling_utils").setLevel(
                logging.ERROR)

        if torchscript:
            assert model
            logger.info(f"Loading traced GPT-2 model from provided {model}.")
            if config is None:
                config = GPT2Config()
            self.torchscript = True
            self.model = GPT2LMHeadModel(config)

            # Copy the traced model's attributes onto a GPT2LMHeadModel
            # instance so it can inherit the class's methods
            pt_model = torch.jit.load(model)
            self.model.transformer = pt_model.transformer
            self.model.lm_head = pt_model.lm_head

        elif tf_gpt2:
            # Download + convert the TF weights if a PyTorch model has not been created
            if not os.path.isfile(
                    os.path.join(cache_dir, f"pytorch_model_{tf_gpt2}.bin")):
                assert tf_gpt2 in [
                    "124M",
                    "355M",
                    "774M",
                    "1558M",
                ], "Invalid TensorFlow GPT-2 model size."

                logger.info(
                    f"Downloading the {tf_gpt2} GPT-2 TensorFlow weights/config "
                    + "from Google's servers")

                download_gpt2(cache_dir, tf_gpt2)

                logger.info(
                    f"Converting the {tf_gpt2} GPT-2 TensorFlow weights to PyTorch."
                )

                config_path = os.path.join(cache_dir, tf_gpt2, "hparams.json")

                convert_gpt2_checkpoint_to_pytorch(
                    os.path.join(cache_dir, tf_gpt2),
                    config_path,
                    cache_dir,
                )

                os.rename(
                    os.path.join(cache_dir, f"pytorch_model.bin"),
                    os.path.join(cache_dir, f"pytorch_model_{tf_gpt2}.bin"),
                )

                os.rename(
                    os.path.join(cache_dir, f"config.json"),
                    os.path.join(cache_dir, f"config_{tf_gpt2}.json"),
                )

            logger.info(f"Loading {tf_gpt2} GPT-2 model from /{cache_dir}.")
            model = os.path.join(cache_dir, f"pytorch_model_{tf_gpt2}.bin")
            config = os.path.join(cache_dir, f"config_{tf_gpt2}.json")

            self.model = GPT2LMHeadModel.from_pretrained(model, config=config)

        elif model and os.path.exists(model):
            # A pytorch_model.bin (+ optional config/config.json) is provided
            logger.info(f"Loading GPT-2 model from provided {model}.")
            if config is None:
                config = GPT2Config()
            if ts_to_trace:
                config.torchscript = True
            self.model = GPT2LMHeadModel.from_pretrained(model, config=config)
        elif config:
            if ts_to_trace:
                config.torchscript = True
            # Manually construct a GPT-2 model from scratch
            logger.info("Constructing GPT-2 model from provided config.")
            self.model = AutoModelWithLMHead.from_config(config=config)
        else:
            # Download and cache model from Huggingface
            if os.path.isdir(cache_dir) and len(os.listdir(cache_dir)) > 0:
                logger.info(
                    f"Loading {model or 'gpt2'} model from /{cache_dir}.")
            else:
                logger.info(
                    f"Downloading {model or 'gpt2'} model to /{cache_dir}.")
            self.model = GPT2LMHeadModel.from_pretrained(
                model or "gpt2", cache_dir=cache_dir, torchscript=ts_to_trace)
            if model and "gpt2" not in model:
                logger.info(f"Using the tokenizer for {model}.")
                self.tokenizer = GPT2Tokenizer.from_pretrained(
                    model,
                    cache_dir=cache_dir,
                )

        if self.tokenizer is None:
            # Update tokenizer settings (if not set already)
            args = locals()
            custom_tokenizer = False
            for attr in [
                    "vocab_file",
                    "merges_file",
                    "bos_token",
                    "eos_token",
                    "unk_token",
            ]:
                if args[attr] is not None:
                    custom_tokenizer = True
                    setattr(self, attr, args[attr])

            if custom_tokenizer:
                logger.info("Using a custom tokenizer.")
            else:
                logger.info("Using the default GPT-2 Tokenizer.")

            self.tokenizer = GPT2Tokenizer(
                vocab_file=self.vocab_file,
                merges_file=self.merges_file,
                bos_token=self.bos_token,
                eos_token=self.eos_token,
                unk_token=self.unk_token,
                pad_token=self.pad_token,
            )

        if to_gpu:
            if to_fp16:
                self.to_fp16()
            self.to_gpu()
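
The constructor resolves its model source in order: a traced TorchScript file, original TensorFlow GPT-2 weights, a local PyTorch checkpoint, a bare config, and finally a download from Hugging Face. A short usage sketch, assuming the public aitextgen class exposes this __init__ together with a generate() method:

from aitextgen import aitextgen

# Download OpenAI's original 124M TensorFlow weights and convert them to PyTorch
ai = aitextgen(tf_gpt2="124M")

# Or fall back to the default "gpt2" checkpoint from Hugging Face:
# ai = aitextgen()

ai.generate(n=1, prompt="Hello")

Note that the TensorFlow branch caches the converted weights as pytorch_model_{size}.bin, so the download and conversion run only once per model size.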
Example #3
    def run(self):
        if self._model_type == "albert":
            try:
                from transformers.convert_albert_original_tf_checkpoint_to_pytorch import (
                    convert_tf_checkpoint_to_pytorch, )
            except ImportError:
                raise ImportError(IMPORT_ERROR_MESSAGE)

            convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config,
                                             self._pytorch_dump_output)
        elif self._model_type == "bert":
            try:
                from transformers.convert_bert_original_tf_checkpoint_to_pytorch import (
                    convert_tf_checkpoint_to_pytorch, )
            except ImportError:
                raise ImportError(IMPORT_ERROR_MESSAGE)

            convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config,
                                             self._pytorch_dump_output)
        elif self._model_type == "funnel":
            try:
                from transformers.convert_funnel_original_tf_checkpoint_to_pytorch import (
                    convert_tf_checkpoint_to_pytorch, )
            except ImportError:
                raise ImportError(IMPORT_ERROR_MESSAGE)

            convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config,
                                             self._pytorch_dump_output)
        elif self._model_type == "gpt":
            from transformers.convert_openai_original_tf_checkpoint_to_pytorch import (
                convert_openai_checkpoint_to_pytorch, )

            convert_openai_checkpoint_to_pytorch(self._tf_checkpoint,
                                                 self._config,
                                                 self._pytorch_dump_output)
        elif self._model_type == "transfo_xl":
            try:
                from transformers.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
                    convert_transfo_xl_checkpoint_to_pytorch, )
            except ImportError:
                raise ImportError(IMPORT_ERROR_MESSAGE)

            if "ckpt" in self._tf_checkpoint.lower():
                TF_CHECKPOINT = self._tf_checkpoint
                TF_DATASET_FILE = ""
            else:
                TF_DATASET_FILE = self._tf_checkpoint
                TF_CHECKPOINT = ""
            convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT,
                                                     self._config,
                                                     self._pytorch_dump_output,
                                                     TF_DATASET_FILE)
        elif self._model_type == "gpt2":
            try:
                from transformers.convert_gpt2_original_tf_checkpoint_to_pytorch import (
                    convert_gpt2_checkpoint_to_pytorch, )
            except ImportError:
                raise ImportError(IMPORT_ERROR_MESSAGE)

            convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint,
                                               self._config,
                                               self._pytorch_dump_output)
        elif self._model_type == "xlnet":
            try:
                from transformers.convert_xlnet_original_tf_checkpoint_to_pytorch import (
                    convert_xlnet_checkpoint_to_pytorch, )
            except ImportError:
                raise ImportError(IMPORT_ERROR_MESSAGE)

            convert_xlnet_checkpoint_to_pytorch(self._tf_checkpoint,
                                                self._config,
                                                self._pytorch_dump_output,
                                                self._finetuning_task_name)
        elif self._model_type == "xlm":
            from transformers.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
                convert_xlm_checkpoint_to_pytorch, )

            convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint,
                                              self._pytorch_dump_output)
        elif self._model_type == "lxmert":
            from transformers.convert_lxmert_original_pytorch_checkpoint_to_pytorch import (
                convert_lxmert_checkpoint_to_pytorch, )

            convert_lxmert_checkpoint_to_pytorch(self._tf_checkpoint,
                                                 self._pytorch_dump_output)
        else:
            raise ValueError(
                "--model_type should be selected in the list "
                "[albert, bert, funnel, gpt, gpt2, transfo_xl, xlnet, xlm, lxmert]"
            )
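
This variant deduplicates the repeated error text behind a module-level IMPORT_ERROR_MESSAGE constant that the excerpt does not show. A plausible definition, reconstructed from the message spelled out in Example #1:

# Assumed definition; the excerpt references IMPORT_ERROR_MESSAGE without defining it.
IMPORT_ERROR_MESSAGE = (
    "transformers can only be used from the command line to convert TensorFlow "
    "models to PyTorch. In that case, it requires TensorFlow to be installed. "
    "Please see https://www.tensorflow.org/install/ for installation instructions."
)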