예제 #1
0
파일: writer.py 프로젝트: shanyi15/X2Paddle
    def write_desc_file(filename, op_descs, var_descs):
        """
        Write a single-block ProgramDesc built from the given descs to a file.

        A ``framework_pb2.ProgramDesc`` with one block (``idx`` 0,
        ``parent_idx`` -1) is filled with ``op_descs`` and ``var_descs``, two
        extra persistable vars named 'feed' and 'fetch' are appended, and the
        serialized message is written to ``filename`` in binary mode.

        Args:
            filename: path of the output file; it is opened with mode 'wb'.
            op_descs: iterable of op descs appended to the block's ``ops``.
            var_descs: iterable of var descs appended to the block's ``vars``.
        """

        prog_desc = framework_pb2.ProgramDesc()
        block_desc = prog_desc.blocks.add()
        block_desc.idx = 0
        block_desc.parent_idx = -1
        block_desc.ops.extend(op_descs)
        block_desc.vars.extend(var_descs)

        # add the 'feed' and 'fetch' vars to the block
        feed_var_desc = block_desc.vars.add()
        feed_var_desc.name = 'feed'
        feed_var_desc.type.type = framework_pb2.VarType.FEED_MINIBATCH
        feed_var_desc.persistable = True
        fetch_var_desc = block_desc.vars.add()
        fetch_var_desc.name = 'fetch'
        fetch_var_desc.type.type = framework_pb2.VarType.FETCH_LIST
        fetch_var_desc.persistable = True

        # The context manager closes the file even when serialization or the
        # write itself raises, which the explicit open()/close() pair did not.
        with open(filename, 'wb') as fp:
            fp.write(prog_desc.SerializeToString())
        logger.debug('saved descs to %s', filename)
예제 #2
0
def load_program_text(model_filename):
    """Build a Program from a text-format ProgramDesc file.

    The file at ``model_filename`` is read as text and merged into an empty
    ``framework_pb2.ProgramDesc`` with ``text_format.Merge``; the message is
    then re-serialized and passed to ``Program.parse_from_string``, whose
    result is returned.
    """
    desc = framework_pb2.ProgramDesc()
    with open(model_filename, "r") as text_file:
        text_format.Merge(text_file.read(), desc)

    serialized = desc.SerializeToString()
    return Program.parse_from_string(serialized)
예제 #3
0
import os
import sys
import paddle.fluid.proto.framework_pb2 as framework_pb2

# Read the serialized program from the "__model__" file in the current working
# directory. The with-block closes the handle even if read() raises.
with open("__model__", "rb") as infile:
    content = infile.read()

# Decode the raw bytes into a ProgramDesc message.
prog = framework_pb2.ProgramDesc()
prog.ParseFromString(content)

# Var and op lists of the first block; `ops` is iterated by
# modify_reshape2_transpose2(). The name `vars` shadows the builtin but is
# kept unchanged because it is a module-level name other code may reference.
vars = prog.blocks[0].vars
ops = prog.blocks[0].ops


def modify_reshape2_transpose2():
    """Rewrite the "axis" attribute of ops in the module-level ``ops`` list.

    Iterates every op in ``ops`` and mutates matching attributes in place;
    nothing is returned. In the code visible here the "reshape2" branch only
    initializes an empty ``shape`` list.
    """
    for op in ops:
        if op.type == "transpose2":
            axis = [1, 0, 2, 3]
        # NOTE(review): this loop sits at the same level as the `if` above, so
        # it runs for every op, not only "transpose2". For other op types
        # `axis` is either unbound (no "transpose2" seen yet) or left over from
        # an earlier iteration -- confirm whether the loop was meant to be
        # nested under the `if`.
        for attr in op.attrs:
            if attr.name == "axis":
                if len(attr.ints) >= 4:
                    # NOTE(review): five pops follow a `>= 4` length check;
                    # with exactly four entries the fifth pop is issued on an
                    # empty container -- presumably an error, verify.
                    attr.ints.pop()
                    attr.ints.pop()
                    attr.ints.pop()
                    attr.ints.pop()
                    attr.ints.pop()
                    # Append the replacement axis order to whatever remains.
                    attr.ints.extend(axis)

        if op.type == "reshape2":
            shape = []
예제 #4
0
                                         padding=(1, 1),
                                         dilation=(1, 1),
                                         groups=1,
                                         bias_attr=False)

    cpu = paddle.static.cpu_places(1)
    exe = paddle.static.Executor(cpu[0])
    exe.run(startup_program)
    inp_dict = {'x': inp_blob}
    var = [test_layer]
    res_paddle = exe.run(paddle.static.default_main_program(),
                         fetch_list=var,
                         feed=inp_dict)
    paddle.static.save_inference_model(os.path.join(sys.argv[1],
                                                    "lower_version/",
                                                    "lower_version"), [x],
                                       [test_layer],
                                       exe,
                                       program=main_program)

# Location of the saved model file under the directory given as argv[1].
model_path = os.path.join(sys.argv[1], "lower_version",
                          "lower_version.pdmodel")

# Parse the saved program into a ProgramDesc message.
fw_model = framework_pb2.ProgramDesc()
with open(model_path, "rb") as src:
    serialized = src.read()
fw_model.ParseFromString(serialized)

# Overwrite the recorded version, echo it, then write the model back in place.
fw_model.version.version = 1800000
print(fw_model.version.version)
with open(model_path, "wb") as dst:
    dst.write(fw_model.SerializeToString())