def convert_model(base_model, path, new_path):
    """Convert a TensorFlow T5 checkpoint into a Hugging Face model directory.

    Args:
        base_model: Name or path given to ``T5Config.from_pretrained`` to
            obtain the architecture the checkpoint was trained with.
        path: TensorFlow checkpoint location, handed to
            ``load_tf_weights_in_t5``.
        new_path: Output location passed to ``model.save_pretrained``.
    """
    # Build the config once so the same object both shapes the model and is
    # handed to the weight loader.
    config = T5Config.from_pretrained(base_model)
    model = T5ForConditionalGeneration(config)

    print("loading weights...")
    # Pass the real config (as the other converters in this file do) instead
    # of None, so this call does not depend on the loader ignoring its
    # ``config`` argument.
    load_tf_weights_in_t5(model, config, path)
    model.eval()

    print("saving HF weights...")
    model.save_pretrained(new_path)
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
    """Convert a TF T5 checkpoint into a raw PyTorch ``state_dict`` file.

    Builds a ``T5Model`` from a JSON config, loads the TF weights into it, and
    writes ``model.state_dict()`` to ``pytorch_dump_path`` with ``torch.save``.

    NOTE(review): another function with this exact name is defined later in
    this file; at import time that later definition rebinds the name, so this
    version is never reachable by name. Confirm which variant is intended.

    Args:
        tf_checkpoint_path: TensorFlow checkpoint to read weights from.
        config_file: JSON file passed to ``T5Config.from_json_file``.
        pytorch_dump_path: Destination file for ``torch.save``.
    """
    t5_config = T5Config.from_json_file(config_file)
    print(f"Building PyTorch model from configuration: {t5_config!s}")
    t5_model = T5Model(t5_config)

    # Load weights from the TF checkpoint into the freshly built model.
    load_tf_weights_in_t5(t5_model, t5_config, tf_checkpoint_path)

    print(f"Save PyTorch model to {pytorch_dump_path}")
    weights = t5_model.state_dict()
    torch.save(weights, pytorch_dump_path)
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
    """Convert a TF T5 checkpoint and save it with ``save_pretrained``.

    Builds a ``T5ForConditionalGeneration`` from a JSON config, loads the TF
    weights into it, and writes it out via ``save_pretrained``.

    Args:
        tf_checkpoint_path: TensorFlow checkpoint to read weights from.
        config_file: JSON file passed to ``T5Config.from_json_file``.
        pytorch_dump_path: Destination handed to ``save_pretrained``.
    """
    cfg = T5Config.from_json_file(config_file)
    print(f"Building PyTorch model from configuration: {cfg}")
    seq2seq = T5ForConditionalGeneration(cfg)

    # Load weights from the TF checkpoint into the freshly built model.
    load_tf_weights_in_t5(seq2seq, cfg, tf_checkpoint_path)

    print(f"Save PyTorch model to {pytorch_dump_path}")
    seq2seq.save_pretrained(pytorch_dump_path)
import torch
from transformers import (
    MT5Config,
    MT5ForConditionalGeneration,
    load_tf_weights_in_t5,
)

# One-off script: build an MT5 model from ./config.json, load TF checkpoint
# weights into it (judging by the path, the chinese_t5_pegasus_base
# checkpoint — confirm), and dump the resulting state_dict to
# ./pytorch_model.bin. Input and output paths are hard-coded below.
config = MT5Config.from_pretrained('config.json')
model = MT5ForConditionalGeneration(config)

ckpt = r'D:\BaiduNetdiskDownload\chinese_t5_pegasus_base\chinese_t5_pegasus_base\model.ckpt'
model = load_tf_weights_in_t5(model, config, ckpt)

torch.save(model.state_dict(), 'pytorch_model.bin')