Example #1
0
import tensorflow as tf
import torch
from transformers import convert_bert_original_tf_checkpoint_to_pytorch as ctp


# Directory holding the original TensorFlow checkpoint and its JSON config;
# every path below is resolved relative to it (note the trailing slash).
BERT_BASE_DIR = "uncased_L-12_H-768_A-12/"
tf_checkpoint_path = f"{BERT_BASE_DIR}bert_model.ckpt"
bert_config_file = f"{BERT_BASE_DIR}bert_config.json"
pytorch_dump_path = f"{BERT_BASE_DIR}pytorch_model.bin"

# Convert checkpoint + config into a PyTorch weights file written into the
# same directory.
ctp.convert_tf_checkpoint_to_pytorch(
    tf_checkpoint_path,
    bert_config_file,
    pytorch_dump_path,
)
Example #2
0
    def run(self):
        """Convert the configured original checkpoint into a PyTorch dump.

        Dispatches on ``self._model_type`` and calls the matching converter
        from ``transformers`` with ``self._tf_checkpoint``, ``self._config``
        and ``self._pytorch_dump_output`` (XLNet additionally receives
        ``self._finetuning_task_name``; XLM receives no config). Converter
        modules are imported inside their branch, so only the selected
        converter is imported.

        Raises:
            ImportError: if the converter module for ``bert``, ``transfo_xl``,
                ``gpt2``, ``xlnet`` or ``t5`` cannot be imported (treated as
                TensorFlow being missing); the original import error is
                chained as ``__cause__``.
            ValueError: if ``self._model_type`` is not a supported name.
        """
        # Shared message for converters whose import failure is reported as
        # "TensorFlow not installed"; defined once instead of per branch.
        tf_import_error_msg = (
            "transformers can only be used from the commandline to convert TensorFlow models in PyTorch. "
            "In that case, it requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )

        if self._model_type == "bert":
            try:
                from transformers.convert_bert_original_tf_checkpoint_to_pytorch import (
                    convert_tf_checkpoint_to_pytorch,
                )
            except ImportError as err:
                # Chain the cause so the actual failing import stays visible.
                raise ImportError(tf_import_error_msg) from err

            convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        elif self._model_type == "gpt":
            # Not wrapped in the TensorFlow guard (as in the original); any
            # import error here propagates unchanged.
            from transformers.convert_openai_original_tf_checkpoint_to_pytorch import (
                convert_openai_checkpoint_to_pytorch,
            )

            convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        elif self._model_type == "transfo_xl":
            try:
                from transformers.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
                    convert_transfo_xl_checkpoint_to_pytorch,
                )
            except ImportError as err:
                raise ImportError(tf_import_error_msg) from err

            # A path containing "ckpt" (case-insensitive) is passed as the
            # checkpoint argument; any other path is passed as the dataset-file
            # argument, with the unused one left as "".
            if "ckpt" in self._tf_checkpoint.lower():
                tf_checkpoint, tf_dataset_file = self._tf_checkpoint, ""
            else:
                tf_checkpoint, tf_dataset_file = "", self._tf_checkpoint
            convert_transfo_xl_checkpoint_to_pytorch(
                tf_checkpoint, self._config, self._pytorch_dump_output, tf_dataset_file
            )
        elif self._model_type == "gpt2":
            try:
                from transformers.convert_gpt2_original_tf_checkpoint_to_pytorch import (
                    convert_gpt2_checkpoint_to_pytorch,
                )
            except ImportError as err:
                raise ImportError(tf_import_error_msg) from err

            convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        elif self._model_type == "xlnet":
            try:
                from transformers.convert_xlnet_original_tf_checkpoint_to_pytorch import (
                    convert_xlnet_checkpoint_to_pytorch,
                )
            except ImportError as err:
                raise ImportError(tf_import_error_msg) from err

            convert_xlnet_checkpoint_to_pytorch(
                self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name
            )
        elif self._model_type == "xlm":
            # The source here is an original *PyTorch* checkpoint (per the
            # module name), so no TensorFlow guard and no config argument.
            from transformers.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
                convert_xlm_checkpoint_to_pytorch,
            )

            convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
        elif self._model_type == "t5":
            # Guarded like the other TF-checkpoint converters for consistency.
            try:
                from transformers.convert_t5_original_tf_checkpoint_to_pytorch import (
                    convert_tf_checkpoint_to_pytorch,
                )
            except ImportError as err:
                raise ImportError(tf_import_error_msg) from err

            convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        else:
            msg = "--model_type should be selected in the list " "[bert, gpt, gpt2, transfo_xl, xlnet, xlm, t5]"
            raise ValueError(msg)
Example #3
0
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File    : conver.py

import transformers.convert_bert_original_tf_checkpoint_to_pytorch as con

import os
import sys

# Directory containing the original TensorFlow WoBERT checkpoint and its
# bert_config.json. The default is the original author's machine-specific
# path; pass another directory as the first command-line argument to convert
# a checkpoint stored elsewhere.
DEFAULT_TF_WOBERT_PATH = '/Users/jiang/Documents/pre_train_models/chinese_wobert_L-12_H-768_A-12'
tf_wobert_path = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_TF_WOBERT_PATH

# os.path.join tolerates a trailing separator on a user-supplied directory.
# NOTE(review): the output is written as pytorch_bert.bin rather than the
# conventional pytorch_model.bin; loading the directory with from_pretrained()
# may not find it under this name — confirm for your transformers version.
con.convert_tf_checkpoint_to_pytorch(
    os.path.join(tf_wobert_path, 'bert_model.ckpt'),
    os.path.join(tf_wobert_path, 'bert_config.json'),
    os.path.join(tf_wobert_path, 'pytorch_bert.bin'),
)