예제 #1
0
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file,
                                     pytorch_dump_path):
    """
    Convert a TensorFlow BERT checkpoint into a saved PyTorch state dict.

    Every variable of the checkpoint is loaded, mapped onto an attribute of a
    freshly built ``BertForPreTraining`` model by walking the ``/``-separated
    variable name, and copied into that attribute.  The model's
    ``state_dict()`` is then written with ``torch.save``.

    Parameters
    ----------
    tf_checkpoint_path:
        Path of the TensorFlow checkpoint to read.
    bert_config_file:
        Path of the JSON file handed to ``BertConfig.from_json_file``.
    pytorch_dump_path:
        Destination handed to ``torch.save``.

    Raises
    ------
    AssertionError
        If a checkpoint array and the PyTorch attribute it maps to have
        different shapes; ``args`` is ``(pointer.shape, array.shape)``.
    """
    config_path = os.path.abspath(bert_config_file)
    tf_path = os.path.abspath(tf_checkpoint_path)
    print("Converting TensorFlow checkpoint from {} with config at {}".format(
        tf_path, config_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        print("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    # Initialise PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = BertForPreTraining(config)

    for name, array in zip(names, arrays):
        name = name.split('/')
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
        # which are not required for using pretrained model
        if any(n in ["adam_v", "adam_m"] for n in name):
            print("Skipping {}".format("/".join(name)))
            continue
        pointer = model
        for m_name in name:
            # A scope such as "layer_3" is split into an attribute name and an
            # index ("layer", "3", ""); any other scope is used as-is.
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                parts = re.split(r'_(\d+)', m_name)
            else:
                parts = [m_name]
            # Translate TensorFlow variable names to the PyTorch attribute names.
            if parts[0] in ('kernel', 'gamma', 'output_weights'):
                pointer = getattr(pointer, 'weight')
            elif parts[0] in ('output_bias', 'beta'):
                pointer = getattr(pointer, 'bias')
            else:
                pointer = getattr(pointer, parts[0])
            if len(parts) >= 2:
                pointer = pointer[int(parts[1])]
        # m_name still holds the last component of the variable name here.
        if m_name[-11:] == '_embeddings':
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            # Kernels are transposed before being copied into 'weight'.
            array = np.transpose(array)
        # Explicit check instead of `assert`, so that running under `python -O`
        # cannot silently load an array of the wrong shape.  The exception type
        # and args are the same as the former assert-and-reraise produced.
        if pointer.shape != array.shape:
            raise AssertionError(pointer.shape, array.shape)
        print("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)
예제 #2
0
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    """
    Build a ``BertForPreTraining`` model, load TensorFlow weights into it and save it.

    Parameters
    ----------
    tf_checkpoint_path:
        Checkpoint path handed to ``load_tf_weights_in_bert``.
    bert_config_file:
        JSON file handed to ``BertConfig.from_json_file``.
    pytorch_dump_path:
        Destination handed to ``torch.save`` for the model's ``state_dict()``.
    """
    # Step 1: build an untrained model from the JSON configuration.
    bert_config = BertConfig.from_json_file(bert_config_file)
    build_message = "Building PyTorch model from configuration: {}".format(str(bert_config))
    print(build_message)
    pytorch_model = BertForPreTraining(bert_config)

    # Step 2: fill the model from the TensorFlow checkpoint.
    load_tf_weights_in_bert(pytorch_model, tf_checkpoint_path)

    # Step 3: write the state dict to disk.
    save_message = "Save PyTorch model to {}".format(pytorch_dump_path)
    print(save_message)
    weights = pytorch_model.state_dict()
    torch.save(weights, pytorch_dump_path)
예제 #3
0
    def convert_all_tensorflow_bert_weights_to_pytorch(self, input_folder: str) -> None:
        """
        Tensorflow to Pytorch weight conversion based on huggingface's library

        A folder is converted when it directly contains at least five files
        that are named ``vocab.txt`` or end in ``.data-00000-of-00001``,
        ``.index``, ``.meta`` or ``.json``.  In that case every ``*.*`` entry
        is moved into a ``tensorflow`` sub-folder, the ``*.txt`` and ``*.json``
        files are copied into a ``pytorch`` sub-folder, and the weights loaded
        from ``tensorflow/bert_model.ckpt`` are saved as
        ``pytorch/pytorch_model.bin``.  Otherwise the method recurses into
        each direct sub-folder.

        Parameters
        ----------
        input_folder: `str`, required
            The folder containing the tensorflow files

        Raises
        ------
        ValueError
            If the folder qualifies for conversion but holds no ``.json``
            config file.
        """
        # Function-scope imports keep this method self-contained.
        import glob
        import shutil

        entry_names = os.listdir(input_folder)
        file_names = [e for e in entry_names if os.path.isfile(os.path.join(input_folder, e))]
        sub_folders = [os.path.join(input_folder, e) for e in entry_names
                       if os.path.isdir(os.path.join(input_folder, e))]

        checkpoint_suffixes = ('.data-00000-of-00001', '.index', '.meta', '.json')
        relevant = [name for name in file_names
                    if name == 'vocab.txt' or name.endswith(checkpoint_suffixes)]

        # Fewer than five checkpoint-related files: not a model folder, so
        # look for model folders one level further down instead.
        if len(relevant) < 5:
            for folder in sub_folders:
                self.convert_all_tensorflow_bert_weights_to_pytorch(input_folder=folder)
            return

        # The last .json file in listing order is used as the config.
        config_file = None
        for name in relevant:
            if name.endswith('.json'):
                config_file = name
        if config_file is None:
            raise ValueError("no valid config file, but is attempting to convert")

        pytorch_path = os.path.join(input_folder, 'pytorch')
        tensorflow_path = os.path.join(input_folder, 'tensorflow')

        # presumably creates the folder when it is missing; verify against the helper
        force_folder_to_exist(pytorch_path)
        force_folder_to_exist(tensorflow_path)

        # Move every "*.*" entry into the tensorflow folder.  glob/shutil are
        # used instead of a shell command so that paths containing spaces or
        # shell metacharacters are handled correctly.
        for src in glob.glob(os.path.join(glob.escape(input_folder), '*.*')):
            shutil.move(src, os.path.join(tensorflow_path, os.path.basename(src)))

        # Keep the vocabulary and the config next to the converted weights.
        for pattern in ('*.txt', '*.json'):
            for src in glob.glob(os.path.join(glob.escape(tensorflow_path), pattern)):
                if os.path.isfile(src):
                    shutil.copy(src, pytorch_path)

        config = BertConfig.from_json_file(os.path.join(tensorflow_path, config_file))
        model = BertForPreTraining(config)
        load_tf_weights_in_bert(model=model, tf_checkpoint_path=os.path.join(tensorflow_path, 'bert_model.ckpt'))
        torch.save(model.state_dict(), os.path.join(pytorch_path, 'pytorch_model.bin'))