from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5


def convert_model(base_model, path, new_path):
    # Build an untrained model from the base model's configuration
    model = T5ForConditionalGeneration(T5Config.from_pretrained(base_model))
    print("loading weights...")
    load_tf_weights_in_t5(model, None, path)
    model.eval()
    print("saving HF weights...")
    model.save_pretrained(new_path)
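
A minimal usage sketch for the helper above; the base model name and the paths are placeholders, not values from the example. Because the weights are saved with save_pretrained, the output directory can be reloaded through the usual Hugging Face API.

convert_model("t5-base", "/path/to/tf_checkpoint/model.ckpt", "/path/to/hf_model")

# Reload the converted model later (placeholder path):
from transformers import T5ForConditionalGeneration
reloaded = T5ForConditionalGeneration.from_pretrained("/path/to/hf_model")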
Example no. 2
import torch
from transformers import T5Config, T5Model, load_tf_weights_in_t5


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = T5Config.from_json_file(config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = T5Model(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_t5(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)
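
Because this variant stores only the raw state dict, reloading it requires rebuilding the model from the same configuration first. A short sketch with placeholder file names:

import torch
from transformers import T5Config, T5Model

config = T5Config.from_json_file("config.json")  # same config used for conversion
model = T5Model(config)
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()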
Example no. 3
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file,
                                     pytorch_dump_path):
    # Initialise PyTorch model
    config = T5Config.from_json_file(config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = T5ForConditionalGeneration(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_t5(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print(f"Save PyTorch model to {pytorch_dump_path}")
    model.save_pretrained(pytorch_dump_path)
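
To run this conversion from the command line, a thin argparse wrapper is enough. A sketch with assumed argument names (not part of the original snippet):

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tf_checkpoint_path", required=True, help="Path to the TensorFlow checkpoint.")
    parser.add_argument("--config_file", required=True, help="Path to the T5 config JSON file.")
    parser.add_argument("--pytorch_dump_path", required=True, help="Output directory for the converted model.")
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)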
Example no. 4
import torch
from transformers import MT5Config, MT5ForConditionalGeneration, load_tf_weights_in_t5

# Build an MT5 model from a local config file
config = MT5Config.from_pretrained('config.json')
model = MT5ForConditionalGeneration(config)

# Path to the TensorFlow checkpoint to convert
ckpt = 'D:\\BaiduNetdiskDownload\\chinese_t5_pegasus_base\\chinese_t5_pegasus_base\\model.ckpt'

# Load the TF weights into the PyTorch model and save its state dict
model = load_tf_weights_in_t5(model, config, ckpt)

torch.save(model.state_dict(), 'pytorch_model.bin')
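
An alternative ending for the same conversion, sketched here as an assumption rather than part of the original snippet: saving in the Hugging Face format (weights plus config) lets the model be reloaded with from_pretrained instead of a manual load_state_dict. The output directory name is a placeholder.

model.save_pretrained("chinese_t5_pegasus_base_pt")
reloaded = MT5ForConditionalGeneration.from_pretrained("chinese_t5_pegasus_base_pt")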