Example #1 (score: 0)
def main(unused_args):
    """Convert a TensorFlow/Caffe/ONNX model into a MACE model.

    All configuration is read from the module-level FLAGS.  The flow is:
    validate the input files and their checksums, build a
    cvt.ConverterOption from the command-line flags, run the
    platform-specific converter and the MACE graph transformer (plus the
    Hexagon/HTA or APU converter for those runtimes), try to dump an HTML
    visualization, and save the result via model_saver.save_model.

    Exits the process with a non-zero status on any validation failure.
    """
    # --- Input file / checksum validation --------------------------------
    if not os.path.isfile(FLAGS.model_file):
        six.print_("Input graph file '" + FLAGS.model_file +
                   "' does not exist!",
                   file=sys.stderr)
        sys.exit(-1)

    model_checksum = file_checksum(FLAGS.model_file)
    if FLAGS.model_checksum != "" and FLAGS.model_checksum != model_checksum:
        six.print_("Model checksum mismatch: %s != %s" %
                   (model_checksum, FLAGS.model_checksum),
                   file=sys.stderr)
        sys.exit(-1)

    # Caffe ships its weights in a separate file; validate it the same way.
    weight_checksum = None
    if FLAGS.platform == 'caffe':
        if not os.path.isfile(FLAGS.weight_file):
            six.print_("Input weight file '" + FLAGS.weight_file +
                       "' does not exist!",
                       file=sys.stderr)
            sys.exit(-1)

        weight_checksum = file_checksum(FLAGS.weight_file)
        if FLAGS.weight_checksum != "" and \
                FLAGS.weight_checksum != weight_checksum:
            six.print_("Weight checksum mismatch: %s != %s" %
                       (weight_checksum, FLAGS.weight_checksum),
                       file=sys.stderr)
            sys.exit(-1)

    if FLAGS.platform not in ['tensorflow', 'caffe', 'onnx']:
        six.print_("platform %s is not supported." % FLAGS.platform,
                   file=sys.stderr)
        sys.exit(-1)
    if FLAGS.runtime not in ['cpu', 'gpu', 'dsp', 'hta', 'apu', 'cpu+gpu']:
        six.print_("runtime %s is not supported." % FLAGS.runtime,
                   file=sys.stderr)
        sys.exit(-1)

    # --- Build converter options from the flags --------------------------
    option = cvt.ConverterOption()
    if FLAGS.graph_optimize_options:
        option.transformer_option = FLAGS.graph_optimize_options.split(',')
    option.winograd = FLAGS.winograd
    option.quantize = FLAGS.quantize
    option.quantize_large_weights = FLAGS.quantize_large_weights
    option.quantize_range_file = FLAGS.quantize_range_file
    option.change_concat_ranges = FLAGS.change_concat_ranges
    option.cl_mem_type = FLAGS.cl_mem_type
    option.device = device_type_map[FLAGS.runtime]
    option.data_type = parse_data_type(FLAGS.data_type, option.device)

    # Input nodes: names/dtypes/formats are comma-separated; shapes and
    # value ranges use ':' so a single shape can itself contain commas.
    input_node_names = FLAGS.input_node.split(',')
    input_data_types = FLAGS.input_data_types.split(',')
    input_node_shapes = FLAGS.input_shape.split(':')
    input_node_formats = FLAGS.input_data_formats.split(",")
    if FLAGS.input_range:
        input_node_ranges = FLAGS.input_range.split(':')
    else:
        input_node_ranges = []
    if len(input_node_names) != len(input_node_shapes):
        raise Exception('input node count and shape count do not match.')
    for i in six.moves.range(len(input_node_names)):
        input_node = cvt.NodeInfo()
        input_node.name = input_node_names[i]
        input_node.data_type = data_type_map[input_data_types[i]]
        input_node.data_format = data_format_map[input_node_formats[i]]
        input_node.shape = parse_int_array_from_str(input_node_shapes[i])
        # 4-D NCHW shapes are transposed to NHWC before conversion.
        if input_node.data_format == cvt.DataFormat.NCHW and\
                len(input_node.shape) == 4:
            input_node.shape = transpose_shape(input_node.shape, [0, 2, 3, 1])
            input_node.data_format = cvt.DataFormat.NHWC
        if len(input_node_ranges) > i:
            input_node.range = parse_float_array_from_str(input_node_ranges[i])
        option.add_input_node(input_node)

    output_node_names = FLAGS.output_node.split(',')
    output_data_types = FLAGS.output_data_types.split(',')
    output_node_shapes = FLAGS.output_shape.split(':')
    output_node_formats = FLAGS.output_data_formats.split(",")
    if len(output_node_names) != len(output_node_shapes):
        raise Exception('output node count and shape count do not match.')
    for i in six.moves.range(len(output_node_names)):
        output_node = cvt.NodeInfo()
        output_node.name = output_node_names[i]
        output_node.data_type = data_type_map[output_data_types[i]]
        output_node.data_format = data_format_map[output_node_formats[i]]
        output_node.shape = parse_int_array_from_str(output_node_shapes[i])
        if output_node.data_format == cvt.DataFormat.NCHW and\
                len(output_node.shape) == 4:
            output_node.shape = transpose_shape(output_node.shape,
                                                [0, 2, 3, 1])
            output_node.data_format = cvt.DataFormat.NHWC
        option.add_output_node(output_node)

    # Nodes whose values will be checked against the source model; they
    # default to the output nodes when --check_node is not given.
    if FLAGS.check_node != '':
        check_node_names = FLAGS.check_node.split(',')
        check_node_shapes = FLAGS.check_shape.split(':')
        if len(check_node_names) != len(check_node_shapes):
            raise Exception('check node count and shape count do not match.')
        for i in six.moves.range(len(check_node_names)):
            check_node = cvt.NodeInfo()
            check_node.name = check_node_names[i]
            check_node.shape = parse_int_array_from_str(check_node_shapes[i])
            option.add_check_node(check_node)
    else:
        option.check_nodes = option.output_nodes

    option.build()

    # --- Convert and transform -------------------------------------------
    print("Transform model to one that can better run on device")
    if FLAGS.platform == 'tensorflow':
        from mace.python.tools.converter_tool import tensorflow_converter
        converter = tensorflow_converter.TensorflowConverter(
            option, FLAGS.model_file)
    elif FLAGS.platform == 'caffe':
        from mace.python.tools.converter_tool import caffe_converter
        converter = caffe_converter.CaffeConverter(option, FLAGS.model_file,
                                                   FLAGS.weight_file)
    elif FLAGS.platform == 'onnx':
        from mace.python.tools.converter_tool import onnx_converter
        converter = onnx_converter.OnnxConverter(option, FLAGS.model_file)
    else:
        # Unreachable in practice: the platform was validated above.
        six.print_("Mace do not support platform %s yet." % FLAGS.platform,
                   file=sys.stderr)
        sys.exit(1)

    output_graph_def = converter.run()
    mace_transformer = transformer.Transformer(option, output_graph_def)
    output_graph_def, quantize_activation_info = mace_transformer.run()

    # DSP-class devices get an extra device-specific conversion pass.
    if option.device in [
            cvt.DeviceType.HEXAGON.value, cvt.DeviceType.HTA.value
    ]:
        from mace.python.tools.converter_tool import hexagon_converter
        converter = hexagon_converter.HexagonConverter(
            option, output_graph_def, quantize_activation_info)
        output_graph_def = converter.run()
    elif FLAGS.runtime == 'apu':
        if FLAGS.platform != 'tensorflow':
            raise Exception('apu only support model from tensorflow')
        from mace.python.tools.converter_tool import apu_converter
        converter = apu_converter.ApuConverter(option, output_graph_def,
                                               quantize_activation_info)
        output_graph_def = converter.run()

    # Visualization is best-effort: a failure must not abort the conversion.
    try:
        visualizer = visualize_model.ModelVisualizer(FLAGS.model_tag,
                                                     output_graph_def)
        visualizer.save_html()
    except Exception:  # noqa
        print("Failed to visualize model:", sys.exc_info()[0])

    model_saver.save_model(option, output_graph_def, model_checksum,
                           weight_checksum, FLAGS.template_dir,
                           FLAGS.obfuscate, FLAGS.model_tag, FLAGS.output_dir,
                           FLAGS.embed_model_data, FLAGS.winograd,
                           FLAGS.model_graph_format)
Example #2 (score: 0)
def main(unused_args):
    """Convert a TensorFlow/Caffe model into a MACE model.

    Reads all configuration from the module-level FLAGS: validates the
    input files and checksums, builds a cvt.ConverterOption, runs either
    the direct TensorFlow-DSP converter (dsp runtime, no quantization) or
    the platform converter followed by the MACE graph transformer (and the
    Hexagon converter for quantized dsp), then saves the result via
    model_saver.save_model.

    Exits the process with a non-zero status on any validation failure.
    """
    # --- Input file / checksum validation --------------------------------
    if not os.path.isfile(FLAGS.model_file):
        six.print_("Input graph file '" + FLAGS.model_file +
                   "' does not exist!",
                   file=sys.stderr)
        sys.exit(-1)

    model_checksum = file_checksum(FLAGS.model_file)
    if FLAGS.model_checksum != "" and FLAGS.model_checksum != model_checksum:
        six.print_("Model checksum mismatch: %s != %s" %
                   (model_checksum, FLAGS.model_checksum),
                   file=sys.stderr)
        sys.exit(-1)

    # Caffe ships its weights in a separate file; validate it the same way.
    weight_checksum = None
    if FLAGS.platform == 'caffe':
        if not os.path.isfile(FLAGS.weight_file):
            six.print_("Input weight file '" + FLAGS.weight_file +
                       "' does not exist!",
                       file=sys.stderr)
            sys.exit(-1)

        weight_checksum = file_checksum(FLAGS.weight_file)
        if FLAGS.weight_checksum != "" and \
                FLAGS.weight_checksum != weight_checksum:
            six.print_("Weight checksum mismatch: %s != %s" %
                       (weight_checksum, FLAGS.weight_checksum),
                       file=sys.stderr)
            sys.exit(-1)

    if FLAGS.platform not in ['tensorflow', 'caffe']:
        six.print_("platform %s is not supported." % FLAGS.platform,
                   file=sys.stderr)
        sys.exit(-1)
    if FLAGS.runtime not in ['cpu', 'gpu', 'dsp', 'cpu+gpu']:
        six.print_("runtime %s is not supported." % FLAGS.runtime,
                   file=sys.stderr)
        sys.exit(-1)

    # --- Build converter options from the flags --------------------------
    option = cvt.ConverterOption()
    if FLAGS.graph_optimize_options:
        option.transformer_option = FLAGS.graph_optimize_options.split(',')
    option.winograd = FLAGS.winograd
    option.quantize = FLAGS.quantize
    option.quantize_range_file = FLAGS.quantize_range_file
    option.change_concat_ranges = FLAGS.change_concat_ranges
    option.cl_mem_type = FLAGS.cl_mem_type

    # Input nodes: names/formats are comma-separated; shapes and value
    # ranges use ':' so a single shape can itself contain commas.
    input_node_names = FLAGS.input_node.split(',')
    input_node_shapes = FLAGS.input_shape.split(':')
    input_node_formats = FLAGS.input_data_formats.split(",")
    if FLAGS.input_range:
        input_node_ranges = FLAGS.input_range.split(':')
    else:
        input_node_ranges = []
    if len(input_node_names) != len(input_node_shapes):
        raise Exception('input node count and shape count do not match.')
    for i in six.moves.range(len(input_node_names)):
        input_node = cvt.NodeInfo()
        input_node.name = input_node_names[i]
        # A single data format is broadcast to every input node.
        if len(input_node_formats) == 1:
            input_node.data_format = data_format_map[input_node_formats[0]]
        else:
            input_node.data_format = data_format_map[input_node_formats[i]]
        input_node.shape = parse_int_array_from_str(input_node_shapes[i])
        if len(input_node_ranges) > i:
            input_node.range = parse_float_array_from_str(input_node_ranges[i])
        option.add_input_node(input_node)

    output_node_names = FLAGS.output_node.split(',')
    output_node_shapes = FLAGS.output_shape.split(':')
    output_node_formats = FLAGS.output_data_formats.split(",")
    if len(output_node_names) != len(output_node_shapes):
        raise Exception('output node count and shape count do not match.')
    for i in six.moves.range(len(output_node_names)):
        output_node = cvt.NodeInfo()
        output_node.name = output_node_names[i]
        # A single data format is broadcast to every output node.
        if len(output_node_formats) == 1:
            output_node.data_format = data_format_map[output_node_formats[0]]
        else:
            output_node.data_format = data_format_map[output_node_formats[i]]
        output_node.shape = parse_int_array_from_str(output_node_shapes[i])
        option.add_output_node(output_node)

    # Optional nodes whose values will be checked against the source model.
    if FLAGS.check_node != '':
        check_node_names = FLAGS.check_node.split(',')
        check_node_shapes = FLAGS.check_shape.split(':')
        if len(check_node_names) != len(check_node_shapes):
            raise Exception('check node count and shape count do not match.')
        for i in six.moves.range(len(check_node_names)):
            check_node = cvt.NodeInfo()
            check_node.name = check_node_names[i]
            check_node.shape = parse_int_array_from_str(check_node_shapes[i])
            option.add_check_node(check_node)

    option.build()

    # --- Convert and transform -------------------------------------------
    print("Transform model to one that can better run on device")
    if FLAGS.runtime == 'dsp' and not option.quantize:
        # Non-quantized DSP models bypass the MACE transformer entirely.
        mace_check(FLAGS.platform == 'tensorflow',
                   'DSP only supports tensorflow')
        from mace.python.tools.converter_tool import tf_dsp_converter
        converter = tf_dsp_converter.TensorflowDspConverter(
            option, FLAGS.model_file)
        output_graph_def = converter.run()
    else:
        if FLAGS.platform == 'tensorflow':
            from mace.python.tools.converter_tool import tensorflow_converter
            converter = tensorflow_converter.TensorflowConverter(
                option, FLAGS.model_file)
        elif FLAGS.platform == 'caffe':
            from mace.python.tools.converter_tool import caffe_converter
            converter = caffe_converter.CaffeConverter(option,
                                                       FLAGS.model_file,
                                                       FLAGS.weight_file)
        else:
            # Unreachable in practice: the platform was validated above.
            six.print_("Mace do not support platform %s yet." % FLAGS.platform,
                       file=sys.stderr)
            sys.exit(1)

        output_graph_def = converter.run()

        option.device = device_type_map[FLAGS.runtime]
        option.data_type = parse_data_type(FLAGS.data_type, option.device)
        mace_transformer = transformer.Transformer(option, output_graph_def)
        output_graph_def, quantize_activation_info = mace_transformer.run()

        # Quantized DSP models get an extra Hexagon conversion pass.
        if FLAGS.runtime == 'dsp':
            from mace.python.tools.converter_tool import hexagon_converter
            converter = hexagon_converter.HexagonConverter(
                option, output_graph_def, quantize_activation_info)
            output_graph_def = converter.run()

    model_saver.save_model(option, output_graph_def, model_checksum,
                           weight_checksum, FLAGS.template_dir,
                           FLAGS.obfuscate, FLAGS.model_tag, FLAGS.output_dir,
                           FLAGS.embed_model_data, FLAGS.winograd,
                           FLAGS.model_graph_format)
Example #3 (score: 0)
def main(unused_args):
    if not os.path.isfile(FLAGS.model_file):
        print("Input graph file '" + FLAGS.model_file + "' does not exist!")
        sys.exit(-1)

    model_checksum = file_checksum(FLAGS.model_file)
    if FLAGS.model_checksum != "" and FLAGS.model_checksum != model_checksum:
        print("Model checksum mismatch: %s != %s" %
              (model_checksum, FLAGS.model_checksum))
        sys.exit(-1)

    weight_checksum = None
    if FLAGS.platform == 'caffe':
        if not os.path.isfile(FLAGS.weight_file):
            print("Input weight file '" + FLAGS.weight_file +
                  "' does not exist!")
            sys.exit(-1)

        weight_checksum = file_checksum(FLAGS.weight_file)
        if FLAGS.weight_checksum != "" and \
                FLAGS.weight_checksum != weight_checksum:
            print("Weight checksum mismatch: %s != %s" %
                  (weight_checksum, FLAGS.weight_checksum))
            sys.exit(-1)

    if FLAGS.platform not in ['tensorflow', 'caffe']:
        print("platform %s is not supported." % FLAGS.platform)
        sys.exit(-1)
    if FLAGS.runtime not in ['cpu', 'gpu', 'dsp', 'cpu+gpu']:
        print("runtime %s is not supported." % FLAGS.runtime)
        sys.exit(-1)

    option = cvt.ConverterOption()
    if FLAGS.graph_optimize_options:
        option.transformer_option = FLAGS.graph_optimize_options.split(',')
    option.winograd = FLAGS.winograd
    option.quantize = FLAGS.quantize
    option.quantize_range_file = FLAGS.quantize_range_file
    option.change_concat_ranges = FLAGS.change_concat_ranges
    option.cl_mem_type = FLAGS.cl_mem_type

    input_node_names = FLAGS.input_node.split(',')
    input_node_shapes = FLAGS.input_shape.split(':')
    if FLAGS.input_range:
        input_node_ranges = FLAGS.input_range.split(':')
    else:
        input_node_ranges = []
    if len(input_node_names) != len(input_node_shapes):
        raise Exception('input node count and shape count do not match.')
    for i in xrange(len(input_node_names)):
        input_node = cvt.NodeInfo()
        input_node.name = input_node_names[i]
        input_node.shape = parse_int_array_from_str(input_node_shapes[i])
        if len(input_node_ranges) > i:
            input_node.range = parse_float_array_from_str(input_node_ranges[i])
        option.add_input_node(input_node)

    output_node_names = FLAGS.output_node.split(',')
    for i in xrange(len(output_node_names)):
        output_node = cvt.NodeInfo()
        output_node.name = output_node_names[i]
        option.add_output_node(output_node)

    option.build()

    print("Transform model to one that can better run on device")
    if FLAGS.runtime == 'dsp':
        mace_check(FLAGS.platform == 'tensorflow',
                   'DSP only supports tensorflow')
        from mace.python.tools.converter_tool import tf_dsp_converter
        converter = tf_dsp_converter.TensorflowDspConverter(
            option, FLAGS.model_file)
        output_graph_def = converter.run()
    else:
        if FLAGS.platform == 'tensorflow':
            from mace.python.tools.converter_tool import tensorflow_converter
            converter = tensorflow_converter.TensorflowConverter(
                option, FLAGS.model_file)
        elif FLAGS.platform == 'caffe':
            from mace.python.tools.converter_tool import caffe_converter
            converter = caffe_converter.CaffeConverter(option,
                                                       FLAGS.model_file,
                                                       FLAGS.weight_file)
        else:
            print("Mace do not support platorm %s yet." & FLAGS.platform)
            exit(1)

        output_graph_def = converter.run()

        if FLAGS.runtime == 'cpu+gpu':
            cpu_graph_def = copy.deepcopy(output_graph_def)

            option.device = cvt.DeviceType.GPU.value
            option.data_type = parse_data_type(FLAGS.data_type,
                                               cvt.DeviceType.GPU.value)
            mace_gpu_transformer = transformer.Transformer(
                option, output_graph_def)
            output_graph_def = mace_gpu_transformer.run()
            print "start optimize gpu memory."
            memory_optimizer.optimize_gpu_memory(output_graph_def)
            print "GPU memory optimization done."

            option.device = cvt.DeviceType.CPU.value
            option.data_type = parse_data_type(FLAGS.data_type,
                                               cvt.DeviceType.CPU.value)
            option.disable_transpose_filters()
            mace_cpu_transformer = transformer.Transformer(
                option, cpu_graph_def)
            cpu_graph_def = mace_cpu_transformer.run()
            print "start optimize cpu memory."
            memory_optimizer.optimize_cpu_memory(cpu_graph_def)
            print "CPU memory optimization done."

            print "Merge cpu and gpu ops together"
            output_graph_def.op.extend(cpu_graph_def.op)
            output_graph_def.mem_arena.mem_block.extend(
                cpu_graph_def.mem_arena.mem_block)
            output_graph_arg_names = set()
            for arg in output_graph_def.arg:
                output_graph_arg_names.add(arg.name)

            for arg in cpu_graph_def.arg:
                if arg.name not in output_graph_arg_names:
                    output_graph_def.arg.extend(arg)
            print "Merge done"
        else:
            option.device = device_type_map[FLAGS.runtime]
            option.data_type = parse_data_type(FLAGS.data_type, option.device)
            mace_transformer = transformer.Transformer(option,
                                                       output_graph_def)
            output_graph_def = mace_transformer.run()

            print "start optimize memory."
            if FLAGS.runtime == 'gpu':
                memory_optimizer.optimize_gpu_memory(output_graph_def)
            elif FLAGS.runtime == 'cpu':
                memory_optimizer.optimize_cpu_memory(output_graph_def)
            else:
                mace_check(False, "runtime only support [gpu|cpu|dsp]")

            print "Memory optimization done."

    model_saver.save_model(output_graph_def, model_checksum, weight_checksum,
                           FLAGS.template_dir, FLAGS.obfuscate,
                           FLAGS.model_tag, FLAGS.output_dir, FLAGS.runtime,
                           FLAGS.embed_model_data, FLAGS.winograd,
                           FLAGS.data_type, FLAGS.model_graph_format)
Example #4 (score: 0)
def mace_convert_model(platform, model_file, model_checksum_in, weight_file,
                       weight_checksum_in, runtime, data_type, input_node,
                       input_shape, output_node, dsp_mode,
                       graph_optimize_options, winograd, template_dir,
                       obfuscate, model_tag, output_dir, embed_model_data,
                       model_graph_format):
    """Convert a trained model into a MACE model and save it.

    Validates the input files (and, for caffe, the weight file) against the
    expected checksums, builds a cvt.ConverterOption, runs the appropriate
    converter and MACE transformer plus memory optimizer (for cpu+gpu the
    graph is transformed once per device and the results merged), then
    saves everything via model_saver.save_model.

    Checksum arguments may be None to skip verification.  Exits the
    process with a non-zero status on any validation failure.

    NOTE(review): dsp_mode is accepted but never used in this body.
    """
    # --- Input file / checksum validation --------------------------------
    if not os.path.isfile(model_file):
        print("Input graph file '" + model_file + "' does not exist!")
        sys.exit(-1)

    model_checksum = file_checksum(model_file)
    if model_checksum_in is not None and model_checksum_in != model_checksum:
        print("Model checksum mismatch: %s != %s" %
              (model_checksum, model_checksum_in))
        sys.exit(-1)

    # Caffe ships its weights in a separate file; validate it the same way.
    weight_checksum = None
    if platform == 'caffe':
        if not os.path.isfile(weight_file):
            print("Input weight file '" + weight_file + "' does not exist!")
            sys.exit(-1)

        weight_checksum = file_checksum(weight_file)
        if weight_checksum_in is not None and \
                        weight_checksum_in != weight_checksum:
            print("Weight checksum mismatch: %s != %s" %
                  (weight_checksum, weight_checksum_in))
            sys.exit(-1)

    # NOTE(review): only 'caffe' passes this check, which makes the
    # tensorflow/dsp branches below unreachable — confirm whether
    # 'tensorflow' should be accepted here too.
    if platform not in ['caffe']:
        print("platform %s is not supported." % platform)
        sys.exit(-1)
    if runtime not in ['cpu', 'gpu', 'dsp', 'cpu+gpu']:
        print("runtime %s is not supported." % runtime)
        sys.exit(-1)

    # --- Build converter options -----------------------------------------
    if graph_optimize_options:
        option = cvt.ConverterOption(graph_optimize_options.split(','))
    else:
        option = cvt.ConverterOption()
    option.winograd = winograd

    # Input names are comma-separated; shapes use ':' so a single shape
    # can itself contain commas.
    input_node_names = input_node.split(',')
    input_node_shapes = input_shape.split(':')
    if len(input_node_names) != len(input_node_shapes):
        raise Exception('input node count and shape count do not match.')
    for i in range(len(input_node_names)):
        input_node = cvt.NodeInfo()
        input_node.name = input_node_names[i]
        input_node.shape = parse_int_array_from_str(input_node_shapes[i])
        option.add_input_node(input_node)

    output_node_names = output_node.split(',')
    for i in range(len(output_node_names)):
        output_node = cvt.NodeInfo()
        output_node.name = output_node_names[i]
        option.add_output_node(output_node)

    # --- Convert, transform, and optimize memory -------------------------
    print("Transform model to one that can better run on device")
    if runtime == 'dsp':
        # DSP models bypass the MACE transformer entirely.
        mace_check(platform == 'tensorflow', 'DSP only supports tensorflow')
        from mace.python.tools.converter_tool import tf_dsp_converter
        converter = tf_dsp_converter.TensorflowDspConverter(option, model_file)
        output_graph_def = converter.run()
    else:
        if platform == 'tensorflow':
            from mace.python.tools.converter_tool import tensorflow_converter
            converter = tensorflow_converter.TensorflowConverter(
                option, model_file)
        elif platform == 'caffe':
            from mace.python.tools.converter_tool import caffe_converter
            converter = caffe_converter.CaffeConverter(option, model_file,
                                                       weight_file)
        else:
            # Unreachable in practice: the platform was validated above.
            print("Mace do not support platform %s yet." % platform)
            sys.exit(1)

        output_graph_def = converter.run()

        print("Transform model to one that can better run on device")
        if runtime == 'cpu+gpu':
            # Transform one copy of the graph per device, then merge.
            cpu_graph_def = copy.deepcopy(output_graph_def)

            option.device = cvt.DeviceType.GPU.value
            option.data_type = parse_data_type(data_type,
                                               cvt.DeviceType.GPU.value)
            mace_gpu_transformer = transformer.Transformer(
                option, output_graph_def)
            output_graph_def = mace_gpu_transformer.run()
            print("start optimize gpu memory.")
            memory_optimizer.optimize_gpu_memory(output_graph_def)
            print("GPU memory optimization done.")

            option.device = cvt.DeviceType.CPU.value
            option.data_type = parse_data_type(data_type,
                                               cvt.DeviceType.CPU.value)
            option.disable_transpose_filters()
            mace_cpu_transformer = transformer.Transformer(
                option, cpu_graph_def)
            cpu_graph_def = mace_cpu_transformer.run()
            print("start optimize cpu memory.")
            memory_optimizer.optimize_cpu_memory(cpu_graph_def)
            print("CPU memory optimization done.")

            print("Merge cpu and gpu ops together")
            output_graph_def.op.extend(cpu_graph_def.op)
            output_graph_def.mem_arena.mem_block.extend(
                cpu_graph_def.mem_arena.mem_block)
            print("Merge done")
        else:
            option.device = device_type_map[runtime]
            option.data_type = parse_data_type(data_type, option.device)
            mace_transformer = transformer.Transformer(option,
                                                       output_graph_def)
            output_graph_def = mace_transformer.run()

            print("start optimize memory.")
            if runtime == 'gpu':
                memory_optimizer.optimize_gpu_memory(output_graph_def)
            elif runtime == 'cpu':
                memory_optimizer.optimize_cpu_memory(output_graph_def)
            else:
                mace_check(False, "runtime only support [gpu|cpu|dsp]")

            print("Memory optimization done.")

    model_saver.save_model(output_graph_def, model_checksum, weight_checksum,
                           template_dir, obfuscate, model_tag, output_dir,
                           runtime, embed_model_data, winograd, data_type,
                           model_graph_format)