def init_multi_net_def(model_file):
    """Load a serialized ``MultiNetDef`` protobuf from ``model_file``.

    Args:
        model_file: Path to a binary-serialized ``mace_pb2.MultiNetDef``.

    Returns:
        The parsed ``mace_pb2.MultiNetDef`` message.

    Fails via ``mace_check`` when ``model_file`` is not an existing regular
    file.
    """
    file_exists = os.path.isfile(model_file)
    mace_check(file_exists,
               "Input graph file '" + model_file + "' does not exist!")

    with open(model_file, "rb") as handle:
        serialized = handle.read()

    parsed = mace_pb2.MultiNetDef()
    parsed.ParseFromString(serialized)
    return parsed
def convert(conf, output, enable_micro=False):
    """Convert every model described in ``conf`` and write the results to disk.

    For each ``model_name`` in ``conf["models"]`` this:
      * creates ``<output>/<model_name>/model`` and
        ``<output>/<model_name>/org_model``;
      * normalizes the model config with ``normalize_model_config`` and
        stores the normalized config back into ``conf["models"]``;
      * converts each subgraph in ``model_conf[ModelKeys.subgraphs]`` with
        ``convert_net``, attempts an HTML visualization (best-effort), and
        merges its parameters with ``merge_params``;
      * packs all subgraphs into one ``mace_pb2.MultiNetDef`` whose
        ``net_def`` entries record their ``data_offset``/``data_size`` in a
        shared parameter buffer.

    Files written to ``<output>/<model_name>/model``:
      * ``<model_name>.pb``     -- the serialized ``MultiNetDef``;
      * ``<model_name>.data``   -- ``bytearray`` of all subgraph params;
      * ``<model_name>.pb_txt`` -- ``str()`` of the ``MultiNetDef``.

    Args:
        conf: Config dict; must contain ``"models"``. If it has a
            ``"quantize_stat"`` key, that value is copied into every
            subgraph config. ``conf["models"]`` is mutated in place.
        output: Root output directory path.
        enable_micro: Forwarded to ``convert_net``; when true,
            ``convert_micro`` is also invoked for each subgraph.
    """
    for model_name, model_conf in conf["models"].items():
        model_output = output + "/" + model_name + "/model"
        org_model_dir = output + "/" + model_name + "/org_model"
        util.mkdir_p(model_output)
        util.mkdir_p(org_model_dir)

        model_conf = normalize_model_config(model_conf, model_output,
                                            org_model_dir)
        # Re-assigning an existing key does not resize the dict, so this is
        # safe while iterating conf["models"].items().
        conf["models"][model_name] = model_conf
        net_confs = model_conf[ModelKeys.subgraphs]

        model = mace_pb2.MultiNetDef()
        add_input_output_tensor(model, model_conf)

        model_params = []
        for net_name, net_conf in net_confs.items():
            if "quantize_stat" in conf:
                net_conf["quantize_stat"] = conf["quantize_stat"]
            net_def_with_data = convert_net(net_name, net_conf, enable_micro)
            try:
                visualizer = visualize_model.ModelVisualizer(
                    net_name, net_def_with_data, model_output)
                visualizer.save_html()
            except Exception:
                # Visualization is best-effort: report the failure and keep
                # converting. Catching Exception (rather than a bare except)
                # lets KeyboardInterrupt / SystemExit still abort the run.
                print("Failed to visualize graph:", sys.exc_info())
            net_def, params = merge_params(net_def_with_data,
                                           net_conf[ModelKeys.data_type])
            if enable_micro:
                convert_micro(
                    model_name,
                    net_confs,
                    net_def,
                    params,
                    model_output,
                )

            # Each subgraph's params are appended to one shared buffer; record
            # where this subgraph's slice starts and how long it is.
            net_def.data_offset = len(model_params)
            net_def.data_size = len(params)
            model.net_def.extend([net_def])
            model_params.extend(params)

        # store model and weight to files
        output_model_file = model_output + "/" + model_name + ".pb"
        output_params_file = model_output + "/" + model_name + ".data"
        with open(output_model_file, "wb") as f:
            f.write(model.SerializeToString())
        with open(output_params_file, "wb") as f:
            f.write(bytearray(model_params))
        with open(output_model_file + "_txt", "w") as f:
            f.write(str(model))
def encrypt(model_name, model_file, params_file, output, is_obfuscate=False,
            gencode_model=False, gencode_params=False):
    """Load a converted model/params pair and re-save it, optionally obfuscated.

    Args:
        model_name: Name forwarded to ``save_model_to_file`` and
            ``save_model_to_code``.
        model_file: Path to the binary-serialized ``mace_pb2.MultiNetDef``.
        params_file: Path to the raw parameter file; read as a ``bytearray``.
        output: Output directory. Generated code (if any) goes to
            ``output + "/code/"``.
        is_obfuscate: If true, ``obfuscate_name`` is applied to every
            ``net_def`` in the model before saving.
        gencode_model: If true, also call ``save_model_to_code`` with the
            checksums of the input files.
        gencode_params: Forwarded to ``save_model_to_code``; only consulted
            when ``gencode_model`` is true.
    """
    # Checksums of the input files, passed to the code generator below.
    model_checksum = util.file_checksum(model_file)
    params_checksum = util.file_checksum(params_file)

    # Use distinct handle names so the path parameters are not shadowed by
    # file objects, and close both inputs before any saving/codegen runs.
    with open(model_file, "rb") as model_fp:
        model_bytes = model_fp.read()
    with open(params_file, "rb") as params_fp:
        params = bytearray(params_fp.read())

    model = mace_pb2.MultiNetDef()
    model.ParseFromString(model_bytes)

    if is_obfuscate:
        for net_def in model.net_def:
            obfuscate_name(net_def)

    save_model_to_file(model_name, model, params, output)
    if gencode_model:
        save_model_to_code(model_name, model, params,
                           model_checksum, params_checksum,
                           output + "/code/", gencode_params)