Code Example #1
def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, full,
                                       gpt2_config_file,
                                       pytorch_dump_folder_path):
    # Imports are done inside the function so users can see usage info before a
    # missing module raises ImportError
    from io import open  # Python 2 compatibility shim; a no-op on Python 3
    from shutil import copyfile
    import logging
    logging.basicConfig(level=logging.INFO)
    from pathlib import Path
    import torch
    # WEIGHTS_NAME = "pytorch_model.bin"
    # CONFIG_NAME = "config.json"
    from transformers import (
        CONFIG_NAME,
        WEIGHTS_NAME,
        GPT2Config,
        GPT2Model,
        load_tf_weights_in_gpt2,
    )
    gpt2_checkpoint_path = Path(gpt2_checkpoint_path)
    print(gpt2_checkpoint_path.name)  # echo which checkpoint is being converted

    if pytorch_dump_folder_path == '':
        prefix = '32BIT-' if full else '16BIT-'
        pytorch_dump_folder_path = 'pytorch-' + prefix + gpt2_checkpoint_path.name
    pytorch_dump_folder_path = Path(pytorch_dump_folder_path)

    pytorch_dump_folder_path.mkdir(exist_ok=True)

    # Construct model
    if gpt2_config_file == "":
        # The default GPT2Config() doesn't seem to work here. We will use the
        # hparams.json file that seems to be included in the checkpoint directory.
        gpt2_config_file = gpt2_checkpoint_path / 'hparams.json'

    config = GPT2Config.from_json_file(gpt2_config_file)
    model = GPT2Model(config)

    # Load weights from numpy
    load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)
    if not full:
        model.half()

    # Save pytorch-model
    pytorch_weights_dump_path = pytorch_dump_folder_path / WEIGHTS_NAME
    pytorch_config_dump_path = pytorch_dump_folder_path / CONFIG_NAME
    print("Save PyTorch model to {}".format(str(pytorch_weights_dump_path)))

    torch.save(model.state_dict(), pytorch_weights_dump_path)

    print("Save configuration file to: " + str(pytorch_config_dump_path))
    with pytorch_config_dump_path.open("w", encoding="utf-8") as f:
        f.write(config.to_json_string())

    # Copy the BPE merges and vocabulary, renamed to the filenames GPT2Tokenizer expects
    copyfile(gpt2_checkpoint_path / 'vocab.bpe',
             pytorch_dump_folder_path / 'merges.txt')
    copyfile(gpt2_checkpoint_path / 'encoder.json',
             pytorch_dump_folder_path / 'vocab.json')
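
A minimal sketch of how Example #1 might be driven from the command line. The argparse flag names below are illustrative assumptions, not taken from the original script:

import argparse

if __name__ == "__main__":
    # Flag names are hypothetical; adjust them to match your own script.
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpt2_checkpoint_path", required=True,
                        help="Directory containing the TensorFlow checkpoint")
    parser.add_argument("--full", action="store_true",
                        help="Keep 32-bit weights instead of halving to fp16")
    parser.add_argument("--gpt2_config_file", default="",
                        help="Optional config JSON; defaults to hparams.json in the checkpoint")
    parser.add_argument("--pytorch_dump_folder_path", default="",
                        help="Output folder; auto-named from the checkpoint when empty")
    args = parser.parse_args()
    convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, args.full,
                                       args.gpt2_config_file,
                                       args.pytorch_dump_folder_path)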
Code Example #2
import torch

from transformers import (
    CONFIG_NAME,
    WEIGHTS_NAME,
    GPT2Config,
    GPT2Model,
    load_tf_weights_in_gpt2,
)


def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
    # Construct model
    if gpt2_config_file == "":
        config = GPT2Config()
    else:
        config = GPT2Config.from_json_file(gpt2_config_file)
    model = GPT2Model(config)

    # Load weights from numpy
    load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)

    # Save pytorch-model
    pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
    pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
    print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
    torch.save(model.state_dict(), pytorch_weights_dump_path)
    print("Save configuration file to {}".format(pytorch_config_dump_path))
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())
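
For completeness, a hedged sketch of calling Example #2 directly. The paths are placeholders, and the output folder must already exist because this version does not create it:

convert_gpt2_checkpoint_to_pytorch(
    "./gpt2-tf-checkpoint",  # placeholder: directory holding the TF checkpoint
    "",                      # empty string falls back to the default GPT2Config()
    "./gpt2-pytorch",        # placeholder: output folder, created beforehand
)

# The dump folder now holds config.json and pytorch_model.bin, so the converted
# model can be reloaded through the standard transformers API:
from transformers import GPT2Model

model = GPT2Model.from_pretrained("./gpt2-pytorch")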