def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, full, gpt2_config_file, pytorch_dump_folder_path):
    """Convert a TensorFlow GPT-2 checkpoint folder into a PyTorch model dump.

    Args:
        gpt2_checkpoint_path: directory holding the TF checkpoint; expected to
            also contain ``hparams.json``, ``vocab.bpe`` and ``encoder.json``.
        full: if True, keep weights in 32-bit; otherwise the model is cast to
            fp16 (``model.half()``) before saving.
        gpt2_config_file: path to a GPT-2 config JSON file; pass ``""`` to use
            the checkpoint's own ``hparams.json``.
        pytorch_dump_folder_path: output directory; pass ``""`` to derive a
            name from the checkpoint folder and the chosen precision.
    """
    # Imports are placed inside the function so users can see usage info
    # before it errors out on missing modules.
    import logging
    from pathlib import Path
    from shutil import copyfile

    import torch

    logging.basicConfig(level=logging.INFO)

    from transformers import (
        CONFIG_NAME,
        WEIGHTS_NAME,
        GPT2Config,
        GPT2Model,
        load_tf_weights_in_gpt2,
    )

    gpt2_checkpoint_path = Path(gpt2_checkpoint_path)
    print(gpt2_checkpoint_path.name)

    # Derive an output folder name that encodes the precision when the caller
    # did not supply one.
    if pytorch_dump_folder_path == '':
        prefix = '32BIT-' if full else '16BIT-'
        pytorch_dump_folder_path = 'pytorch-' + prefix + gpt2_checkpoint_path.name
    pytorch_dump_folder_path = Path(pytorch_dump_folder_path)
    pytorch_dump_folder_path.mkdir(exist_ok=True)

    # Construct model
    if gpt2_config_file == "":
        # GPT2Config() defaults don't seem to work for these checkpoints; use
        # the hparams.json file that ships alongside the TF checkpoint instead.
        gpt2_config_file = gpt2_checkpoint_path / 'hparams.json'
    config = GPT2Config.from_json_file(gpt2_config_file)
    model = GPT2Model(config)

    # Load weights from numpy
    load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)
    if not full:
        # Halve the dump size by casting to fp16.
        model.half()

    # Save pytorch-model
    pytorch_weights_dump_path = pytorch_dump_folder_path / WEIGHTS_NAME
    pytorch_config_dump_path = pytorch_dump_folder_path / CONFIG_NAME
    print("Save PyTorch model to {}".format(str(pytorch_weights_dump_path)))
    torch.save(model.state_dict(), pytorch_weights_dump_path)
    print("Save configuration file to: " + str(pytorch_config_dump_path))
    with pytorch_config_dump_path.open("w", encoding="utf-8") as f:
        f.write(config.to_json_string())

    # Copy the tokenizer files under the names transformers expects
    # (merges.txt / vocab.json).
    copyfile(gpt2_checkpoint_path / 'vocab.bpe', pytorch_dump_folder_path / 'merges.txt')
    copyfile(gpt2_checkpoint_path / 'encoder.json', pytorch_dump_folder_path / 'vocab.json')
def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
    """Convert a TensorFlow GPT-2 checkpoint into a PyTorch weights + config dump.

    NOTE(review): this file defines a second function with the same name but a
    different signature — whichever is defined later shadows the other; confirm
    which one callers intend to use.

    Args:
        gpt2_checkpoint_path: path to the TF checkpoint to load.
        gpt2_config_file: path to a GPT-2 config JSON file; pass ``""`` to use
            the library's default ``GPT2Config()``.
        pytorch_dump_folder_path: existing directory to write
            ``pytorch_model.bin`` (WEIGHTS_NAME) and ``config.json``
            (CONFIG_NAME) into.
    """
    import os

    # Construct model
    if gpt2_config_file == "":
        config = GPT2Config()
    else:
        config = GPT2Config.from_json_file(gpt2_config_file)
    model = GPT2Model(config)

    # Load weights from numpy
    load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)

    # Save pytorch-model; os.path.join is portable, unlike manual "/" joins.
    pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
    pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
    print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
    torch.save(model.state_dict(), pytorch_weights_dump_path)
    print("Save configuration file to {}".format(pytorch_config_dump_path))
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())