def write_desc_file(filename, op_descs, var_descs):
    """Write a single-block program description to ``filename``.

    Builds a ``framework_pb2.ProgramDesc`` with one root block (``idx`` 0,
    ``parent_idx`` -1) containing ``op_descs`` and ``var_descs``, appends the
    persistable ``feed`` (FEED_MINIBATCH) and ``fetch`` (FETCH_LIST)
    variables, and writes the serialized protobuf bytes to ``filename``.

    Args:
        filename: Destination path. It is opened in binary write mode, so an
            existing file is overwritten.
        op_descs: Iterable of op descriptors appended to the block's ``ops``
            (presumably ``framework_pb2.OpDesc`` messages; confirm with callers).
        var_descs: Iterable of var descriptors appended to the block's ``vars``
            (presumably ``framework_pb2.VarDesc`` messages).

    Raises:
        OSError: If ``filename`` cannot be opened or written.
    """
    prog_desc = framework_pb2.ProgramDesc()
    block_desc = prog_desc.blocks.add()
    block_desc.idx = 0
    block_desc.parent_idx = -1
    block_desc.ops.extend(op_descs)
    block_desc.vars.extend(var_descs)

    # Add the special feed/fetch variables to the block's vars.
    feed_var_desc = block_desc.vars.add()
    feed_var_desc.name = 'feed'
    feed_var_desc.type.type = framework_pb2.VarType.FEED_MINIBATCH
    feed_var_desc.persistable = True

    fetch_var_desc = block_desc.vars.add()
    fetch_var_desc.name = 'fetch'
    fetch_var_desc.type.type = framework_pb2.VarType.FETCH_LIST
    fetch_var_desc.persistable = True

    # Serialize before opening the file. Otherwise a serialization error
    # (e.g. protobuf EncodeError) would leave a truncated, empty file behind.
    payload = prog_desc.SerializeToString()
    # The context manager closes the handle even if the write fails.
    with open(filename, 'wb') as fp:
        fp.write(payload)
    logger.debug('saved descs to %s', filename)
def load_program_text(model_filename):
    """Load a program from a human-readable protobuf text file.

    The text content of ``model_filename`` is merged into a fresh
    ``framework_pb2.ProgramDesc`` via ``text_format.Merge``. The message is
    re-serialized to bytes and handed to ``Program.parse_from_string``.

    Args:
        model_filename: Path to a protobuf text-format ``ProgramDesc`` file,
            opened in text mode with the platform default encoding.

    Returns:
        The object produced by ``Program.parse_from_string``.
    """
    desc = framework_pb2.ProgramDesc()
    with open(model_filename, "r") as text_file:
        text_format.Merge(text_file.read(), desc)
    binary_desc = desc.SerializeToString()
    return Program.parse_from_string(binary_desc)
import os import sys import paddle.fluid.proto.framework_pb2 as framework_pb2 infile = open("__model__", "rb") content = infile.read() infile.close() prog = framework_pb2.ProgramDesc() prog.ParseFromString(content) vars = prog.blocks[0].vars ops = prog.blocks[0].ops def modify_reshape2_transpose2(): for op in ops: if op.type == "transpose2": axis = [1, 0, 2, 3] for attr in op.attrs: if attr.name == "axis": if len(attr.ints) >= 4: attr.ints.pop() attr.ints.pop() attr.ints.pop() attr.ints.pop() attr.ints.pop() attr.ints.extend(axis) if op.type == "reshape2": shape = []
padding=(1, 1), dilation=(1, 1), groups=1, bias_attr=False) cpu = paddle.static.cpu_places(1) exe = paddle.static.Executor(cpu[0]) exe.run(startup_program) inp_dict = {'x': inp_blob} var = [test_layer] res_paddle = exe.run(paddle.static.default_main_program(), fetch_list=var, feed=inp_dict) paddle.static.save_inference_model(os.path.join(sys.argv[1], "lower_version/", "lower_version"), [x], [test_layer], exe, program=main_program) fw_model = framework_pb2.ProgramDesc() with open(os.path.join(sys.argv[1], "lower_version", "lower_version.pdmodel"), mode='rb') as file: fw_model.ParseFromString(file.read()) fw_model.version.version = 1800000 print(fw_model.version.version) with open(os.path.join(sys.argv[1], "lower_version", "lower_version.pdmodel"), "wb") as f: f.write(fw_model.SerializeToString())