def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    """Convert a TensorFlow BERT checkpoint into a PyTorch ``state_dict`` file.

    Parameters
    ----------
    tf_checkpoint_path : str
        Path (prefix) of the TensorFlow checkpoint, e.g. ``.../bert_model.ckpt``.
    bert_config_file : str
        Path to the ``bert_config.json`` describing the model architecture.
    pytorch_dump_path : str
        Destination file for the converted PyTorch state dict.
    """
    config_path = os.path.abspath(bert_config_file)
    tf_path = os.path.abspath(tf_checkpoint_path)
    print("Converting TensorFlow checkpoint from {} with config at {}".format(
        tf_path, config_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        print("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    # Initialise PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = BertForPreTraining(config)

    for name, array in zip(names, arrays):
        name = name.split('/')
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
        # which are not required for using pretrained model
        if any(n in ("adam_v", "adam_m") for n in name):
            print("Skipping {}".format("/".join(name)))
            continue
        # Walk the PyTorch module tree following the TF variable-scope path.
        pointer = model
        for m_name in name:
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                # e.g. "layer_11" -> ["layer", "11", ""]: attribute name + list index
                # (was `l`, renamed: `l` is ambiguous with `1`/`I`, PEP 8 E741)
                scope_names = re.split(r'_(\d+)', m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == 'kernel' or scope_names[0] == 'gamma':
                pointer = getattr(pointer, 'weight')
            elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta':
                pointer = getattr(pointer, 'bias')
            elif scope_names[0] == 'output_weights':
                pointer = getattr(pointer, 'weight')
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                # Numeric suffix indexes into a ModuleList (e.g. encoder layers).
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name.endswith('_embeddings'):
            # Embedding tables live on the layer's .weight attribute.
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            # TF dense kernels are (in, out); PyTorch Linear weights are (out, in).
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            # Attach both shapes so mismatches are easy to diagnose from the traceback.
            e.args += (pointer.shape, array.shape)
            raise
        print("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    """Build a ``BertForPreTraining`` model from a JSON config, populate it from
    a TensorFlow checkpoint, and write the resulting state dict to disk.

    Parameters
    ----------
    tf_checkpoint_path : str
        Path (prefix) of the TensorFlow checkpoint to convert.
    bert_config_file : str
        Path to the ``bert_config.json`` architecture description.
    pytorch_dump_path : str
        Destination file for the converted PyTorch weights.
    """
    # Construct the (randomly initialised) PyTorch model from the config file.
    bert_config = BertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(bert_config)))
    bert_model = BertForPreTraining(bert_config)

    # Copy every variable from the TF checkpoint into the matching PyTorch modules.
    load_tf_weights_in_bert(bert_model, tf_checkpoint_path)

    # Persist only the weights (state dict), not the module object itself.
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(bert_model.state_dict(), pytorch_dump_path)
def convert_all_tensorflow_bert_weights_to_pytorch(self, input_folder: str) -> None:
    """
    Tensorflow to Pytorch weight conversion based on huggingface's library

    Recursively walks ``input_folder``. Any directory that directly contains a
    complete TF BERT checkpoint (vocab.txt, the .data/.index/.meta shards and a
    JSON config) is split into ``tensorflow/`` and ``pytorch/`` subfolders, and
    the weights are converted into ``pytorch/pytorch_model.bin``.

    Parameters
    ----------
    input_folder: `str`, required
        The folder containing the tensorflow files
    """
    # Local imports: replaces the former `os.system('mv/cp ...')` shelling-out.
    import glob
    import shutil

    entries = os.listdir(input_folder)
    files = [e for e in entries if os.path.isfile(os.path.join(input_folder, e))]
    folders = [os.path.join(input_folder, e) for e in entries
               if os.path.isdir(os.path.join(input_folder, e))]

    # Starts at -4 so `flag > 0` requires at least five matching files — i.e.
    # all checkpoint pieces (vocab, data, index, meta, json) present at once.
    flag = -4
    # FIX: previously unbound when no .json existed, causing a NameError instead
    # of the intended assertion message below.
    config_file = None
    for file in files:
        if file == 'vocab.txt' or \
                file.endswith('.data-00000-of-00001') or \
                file.endswith('.index') or \
                file.endswith('.meta') or \
                file.endswith('.json'):
            flag += 1
        if file.endswith('.json'):
            config_file = file

    if flag > 0:
        assert type(config_file) == str, "no valid config file, but is attempting to convert"
        pytorch_path = os.path.join(input_folder, 'pytorch')
        tensorflow_path = os.path.join(input_folder, 'tensorflow')
        force_folder_to_exist(pytorch_path)
        force_folder_to_exist(tensorflow_path)
        # FIX: use glob + shutil instead of string-concatenated `os.system('mv/cp ...')`
        # — robust against spaces/shell metacharacters in paths and portable beyond POSIX.
        # Same patterns as before: move every '*.*' entry, then copy vocab/config back.
        for path in glob.glob(os.path.join(input_folder, '*.*')):
            shutil.move(path, tensorflow_path)
        for pattern in ('*.txt', '*.json'):
            for path in glob.glob(os.path.join(tensorflow_path, pattern)):
                shutil.copy(path, pytorch_path)
        config = BertConfig.from_json_file(os.path.join(tensorflow_path, config_file))
        model = BertForPreTraining(config)
        load_tf_weights_in_bert(model=model,
                                tf_checkpoint_path=os.path.join(tensorflow_path, 'bert_model.ckpt'))
        torch.save(model.state_dict(), os.path.join(pytorch_path, 'pytorch_model.bin'))
    else:
        # No checkpoint at this level; recurse into subdirectories.
        for folder in folders:
            self.convert_all_tensorflow_bert_weights_to_pytorch(input_folder=folder)