Example #1
0
def init_multi_net_def(model_file):
    """Load and return a ``mace_pb2.MultiNetDef`` parsed from *model_file*.

    Args:
        model_file: Path to a binary-serialized ``MultiNetDef`` protobuf.

    Returns:
        The parsed ``mace_pb2.MultiNetDef`` message.

    Fails via ``mace_check`` when *model_file* is not an existing regular file.
    """
    file_exists = os.path.isfile(model_file)
    mace_check(file_exists,
               "Input graph file '" + model_file + "' does not exist!")

    # Slurp the raw bytes first, then decode them into a fresh message.
    with open(model_file, "rb") as graph_fp:
        serialized = graph_fp.read()

    net_def = mace_pb2.MultiNetDef()
    net_def.ParseFromString(serialized)
    return net_def
Example #2
0
def convert(conf, output, enable_micro=False):
    """Convert every model described in *conf* and write the results to disk.

    For each ``(model_name, model_conf)`` in ``conf["models"]`` this:

    * creates ``<output>/<model_name>/model`` and
      ``<output>/<model_name>/org_model``;
    * normalizes the model config and stores the normalized version back
      into ``conf["models"][model_name]``;
    * converts each subgraph with ``convert_net``, tries to render an HTML
      visualization of it, and merges its weights via ``merge_params``;
    * writes into the ``model`` directory:
        - ``<model_name>.pb``     -- the serialized ``MultiNetDef``,
        - ``<model_name>.data``   -- all nets' params concatenated,
        - ``<model_name>.pb_txt`` -- ``str()`` of the ``MultiNetDef``.

    Args:
        conf: Conversion config. Must contain a ``"models"`` mapping; an
            optional ``"quantize_stat"`` value is copied into every subgraph
            config. Mutated in place.
        output: Root output directory.
        enable_micro: Forwarded to ``convert_net``; when true,
            ``convert_micro`` is also run for each converted net.
    """
    for model_name, model_conf in conf["models"].items():
        model_output = output + "/" + model_name + "/model"
        org_model_dir = output + "/" + model_name + "/org_model"
        util.mkdir_p(model_output)
        util.mkdir_p(org_model_dir)

        model_conf = normalize_model_config(model_conf, model_output,
                                            org_model_dir)
        conf["models"][model_name] = model_conf
        net_confs = model_conf[ModelKeys.subgraphs]

        model = mace_pb2.MultiNetDef()
        add_input_output_tensor(model, model_conf)

        # Params of all nets in this model, concatenated in net order.
        model_params = []
        for net_name, net_conf in net_confs.items():
            if "quantize_stat" in conf:
                net_conf["quantize_stat"] = conf["quantize_stat"]
            net_def_with_data = convert_net(net_name, net_conf, enable_micro)

            # Visualization is best-effort: report the failure and keep
            # converting. Catch Exception (not a bare except) so that
            # KeyboardInterrupt / SystemExit still abort the conversion.
            try:
                visualizer = visualize_model.ModelVisualizer(
                    net_name, net_def_with_data, model_output)
                visualizer.save_html()
            except Exception:
                print("Failed to visualize graph:", sys.exc_info())

            net_def, params = merge_params(net_def_with_data,
                                           net_conf[ModelKeys.data_type])
            if enable_micro:
                # NOTE(review): invoked once per subgraph but receives the
                # whole net_confs mapping -- confirm this is intended.
                convert_micro(
                    model_name,
                    net_confs,
                    net_def,
                    params,
                    model_output,
                )

            # Record where this net's params start in the model-wide blob
            # and how long they are, before appending them to it.
            net_def.data_offset = len(model_params)
            net_def.data_size = len(params)
            model.net_def.extend([net_def])
            model_params.extend(params)

        # store model and weight to files
        output_model_file = model_output + "/" + model_name + ".pb"
        output_params_file = model_output + "/" + model_name + ".data"
        with open(output_model_file, "wb") as f:
            f.write(model.SerializeToString())
        with open(output_params_file, "wb") as f:
            f.write(bytearray(model_params))
        with open(output_model_file + "_txt", "w") as f:
            f.write(str(model))
Example #3
0
def encrypt(model_name,
            model_file,
            params_file,
            output,
            is_obfuscate=False,
            gencode_model=False,
            gencode_params=False):
    """Load a converted model, optionally obfuscate it, and save it.

    Reads the binary ``MultiNetDef`` from *model_file* and the raw weight
    bytes from *params_file*. It then applies ``obfuscate_name`` to every
    ``net_def`` if requested, and saves the result with
    ``save_model_to_file``. When *gencode_model* is set, it also calls
    ``save_model_to_code`` with output directory ``output + "/code/"``.

    Args:
        model_name: Name forwarded to the save helpers.
        model_file: Path to the serialized ``MultiNetDef``.
        params_file: Path to the model's weight data file.
        output: Output directory forwarded to the save helpers.
        is_obfuscate: Obfuscate names in every ``net_def`` before saving.
        gencode_model: Also generate model code via ``save_model_to_code``.
        gencode_params: Forwarded to ``save_model_to_code``. It only has an
            effect when *gencode_model* is true.
    """
    # Checksums of the input files, passed on to code generation.
    model_checksum = util.file_checksum(model_file)
    params_checksum = util.file_checksum(params_file)

    # Read both inputs up front under distinct handle names. This keeps the
    # path parameters from being shadowed by file objects, and it closes the
    # files before obfuscation and saving run.
    with open(model_file, "rb") as model_fp:
        model_bytes = model_fp.read()
    with open(params_file, "rb") as params_fp:
        params = bytearray(params_fp.read())

    model = mace_pb2.MultiNetDef()
    model.ParseFromString(model_bytes)

    if is_obfuscate:
        for net_def in model.net_def:
            obfuscate_name(net_def)
    save_model_to_file(model_name, model, params, output)
    if gencode_model:
        save_model_to_code(model_name, model, params, model_checksum,
                           params_checksum, output + "/code/",
                           gencode_params)