def load_ernie_model(static_model_path, pretrained_name='ernie-1.0',
                     save_dir="./ernie_checkpoint"):
    """Convert static-graph ERNIE parameters into a dygraph model and save it.

    Args:
        static_model_path: Path handed to ``static.load_program_state`` to read
            the static-graph parameters.
        pretrained_name: Name handed to ``ErnieModel.from_pretrained``; the
            resulting model is the skeleton that receives the converted
            weights. Defaults to ``'ernie-1.0'`` (the previously hard-coded
            value).
        save_dir: Directory handed to ``model.save_pretrained``. Defaults to
            ``"./ernie_checkpoint"`` (the previously hard-coded value).

    Returns:
        The ``ErnieModel`` instance with the converted weights loaded, so the
        caller can use it without reloading from ``save_dir``.
    """
    model = ErnieModel.from_pretrained(pretrained_name)
    program_state = static.load_program_state(static_model_path)
    # Map the static program-state entries onto this model's parameter keys.
    ret_dict = static_params_to_dygraph(model, program_state)

    # Print the word-embedding weights before and after loading so the
    # conversion can be checked by eye. The labels read "parameters before
    # conversion:" / "parameters after conversion:".
    print('转换前的参数:')
    print(model.embeddings.word_embeddings.weight)
    model.load_dict(ret_dict)
    print('转换后的参数:')
    print(model.embeddings.word_embeddings.weight)
    model.save_pretrained(save_dir)
    return model
# Example 2
def load_dygraph_pretrain(model, path=None, load_static_weights=False):
    """Load pretrained weights into a dygraph ``model`` in place.

    Args:
        model: Layer whose parameters are overwritten via ``model.set_dict``.
        path: Weight path prefix. Accepted when it is an existing directory or
            when ``path + '.pdparams'`` exists.
        load_static_weights: If True, read a static-graph program state from
            ``path`` with ``load_program_state`` and match its entries to the
            model by parameter name (``param.name``). Parameters with no match
            keep their current value. Otherwise ``path + '.pdparams'`` is read
            with ``paddle.load``.

    Raises:
        ValueError: If ``path`` is None or does not point to existing weights.
    """
    # Without this guard a None path fails inside os.path with an opaque
    # TypeError instead of a clear message.
    if path is None:
        raise ValueError("Model pretrain path must not be None.")
    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
        raise ValueError("Model pretrain path {} does not "
                         "exists.".format(path))
    if load_static_weights:
        pre_state_dict = load_program_state(path)
        param_state_dict = {}
        for key, param in model.state_dict().items():
            # Static checkpoints are keyed by the parameter's unique name,
            # not by the dygraph structured key.
            weight_name = param.name
            if weight_name in pre_state_dict:
                logger.info('Load weight: {}, shape: {}'.format(
                    weight_name, pre_state_dict[weight_name].shape))
                param_state_dict[key] = pre_state_dict[weight_name]
            else:
                # No counterpart in the checkpoint: keep the current value.
                param_state_dict[key] = param
        model.set_dict(param_state_dict)
        return

    param_state_dict = paddle.load(path + ".pdparams")
    model.set_dict(param_state_dict)
    return
# Example 3
def _dump_state_summary(state, dump_path):
    """Write one summary entry per item of ``state`` to ``dump_path``.

    Each entry is written as ``key<i>-<name>  :<desc>``, preceded by an empty
    line and followed by a newline. ``<desc>`` is ``(<shape>)`` for numpy
    arrays and ``str(value)`` for anything else. The parent directory of
    ``dump_path`` is created if it does not exist yet.
    """
    dump_dir = os.path.dirname(dump_path)
    if dump_dir:
        os.makedirs(dump_dir, exist_ok=True)
    # `with` guarantees the file is closed even if formatting an entry fails.
    with open(dump_path, 'w', encoding='utf-8') as fout:
        for idx, (key, value) in enumerate(state.items()):
            if isinstance(value, np.ndarray):
                desc = "({})".format(value.shape)
            else:
                desc = str(value)
            fout.write("\nkey{}-{}  :{}\n".format(idx, key, desc))


def load_dygraph_pretrain(model, path=None, load_static_weights=False,
                          dump_path='./output/hrnet_parameter.txt'):
    """Load pretrained weights into ``model`` and dump a parameter summary.

    Args:
        model: Layer whose parameters are overwritten via ``model.set_dict``.
        path: Weight path prefix. Accepted when it is an existing directory or
            when ``path + '.pdparams'`` exists.
        load_static_weights: If True, read a static-graph program state from
            ``path`` with ``load_program_state`` and match its entries to the
            model by parameter name (``param.name``). Parameters with no match
            keep their current value. No summary is written in this mode.
        dump_path: File that receives a per-entry summary (name and shape) of
            the loaded ``.pdparams`` state. Defaults to the previously
            hard-coded location; pass None to skip writing the summary.

    Returns:
        None when ``load_static_weights`` is True. Otherwise, the state read
        from ``path + '.pdparams'`` with ``fluid.io.load_program_state``.

    Raises:
        ValueError: If ``path`` is None or does not point to existing weights.
    """
    # Without this guard a None path fails inside os.path with an opaque
    # TypeError instead of a clear message.
    if path is None:
        raise ValueError("Model pretrain path must not be None.")
    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
        raise ValueError("Model pretrain path {} does not "
                         "exists.".format(path))
    if load_static_weights:
        pre_state_dict = load_program_state(path)
        param_state_dict = {}
        for key, param in model.state_dict().items():
            # Static checkpoints are keyed by the parameter's unique name,
            # not by the dygraph structured key.
            weight_name = param.name
            if weight_name in pre_state_dict:
                logger.info('Load weight: {}, shape: {}'.format(
                    weight_name, pre_state_dict[weight_name].shape))
                param_state_dict[key] = pre_state_dict[weight_name]
            else:
                # No counterpart in the checkpoint: keep the current value.
                param_state_dict[key] = param
        model.set_dict(param_state_dict)
        return

    params_path = path + ".pdparams"
    param_state_dict = paddle.load(params_path)
    # The same file is read a second time through the program-state API; this
    # copy is what gets summarized and returned to the caller.
    state = fluid.io.load_program_state(params_path)
    model.set_dict(param_state_dict)
    print("parameters of paddle", type(state), len(state))

    if dump_path is not None:
        _dump_state_summary(state, dump_path)
    return state
import paddle
import paddle.static as static

# Demo: build a small static-graph program, save its state to disk, read the
# state back, and load it into the same program.

# Switch Paddle to static-graph mode before any graph construction below.
paddle.enable_static()

# Tiny network: a float32 input of shape [10, 10] followed by two `fc` layers
# with output size 10. They are added to the default main program, which is
# fetched as `prog` further down.
x = static.data(name="x", shape=[10, 10], dtype='float32')
y = static.nn.fc(x, 10)
z = static.nn.fc(y, 10)

# Run the default startup program on the CPU so the parameters are
# initialized before they are saved.
place = paddle.CPUPlace()
exe = static.Executor(place)
exe.run(static.default_startup_program())
prog = static.default_main_program()

# Round-trip the program state using the path prefix "./temp": save it, then
# read it back. NOTE(review): presumably writes files such as
# ./temp.pdparams next to the working directory; confirm against the Paddle docs.
static.save(prog, "./temp")
program_state = static.load_program_state("./temp")

# Load the state that was just read back into `prog`.
static.set_program_state(prog, program_state)