Example #1
0
def main(unused_args):
    """Convert a TensorFlow or Caffe model into a MACE graph and save it.

    Every setting is read from the module-level ``FLAGS``. The function:

    * validates the model file (and, for Caffe, the weight file) and the
      optional expected checksums;
    * validates the platform and runtime;
    * fills a ``cvt.ConverterOption`` with the input/output/check nodes;
    * converts the model and runs the MACE transformer and memory optimizer
      for the selected runtime;
    * passes the resulting graph to ``model_saver.save_model``.

    For runtime ``'cpu+gpu'`` the converted graph is transformed twice
    (once for GPU, once for CPU). The CPU ops, memory blocks and any graph
    arguments not already present are then appended to the GPU graph.

    Args:
        unused_args: Ignored.

    Raises:
        Exception: If the number of names and shapes given for the input,
            output or check nodes differ.

    The process exits with status -1 when an input file is missing, a
    checksum does not match, or the platform/runtime is unsupported.
    """
    if not os.path.isfile(FLAGS.model_file):
        six.print_("Input graph file '" + FLAGS.model_file +
                   "' does not exist!",
                   file=sys.stderr)
        sys.exit(-1)

    # An empty expected checksum disables verification.
    model_checksum = file_checksum(FLAGS.model_file)
    if FLAGS.model_checksum != "" and FLAGS.model_checksum != model_checksum:
        six.print_("Model checksum mismatch: %s != %s" %
                   (model_checksum, FLAGS.model_checksum),
                   file=sys.stderr)
        sys.exit(-1)

    # Only Caffe models are read together with a separate weight file.
    weight_checksum = None
    if FLAGS.platform == 'caffe':
        if not os.path.isfile(FLAGS.weight_file):
            six.print_("Input weight file '" + FLAGS.weight_file +
                       "' does not exist!",
                       file=sys.stderr)
            sys.exit(-1)

        weight_checksum = file_checksum(FLAGS.weight_file)
        if FLAGS.weight_checksum != "" and \
                FLAGS.weight_checksum != weight_checksum:
            six.print_("Weight checksum mismatch: %s != %s" %
                       (weight_checksum, FLAGS.weight_checksum),
                       file=sys.stderr)
            sys.exit(-1)

    if FLAGS.platform not in ['tensorflow', 'caffe']:
        six.print_("platform %s is not supported." % FLAGS.platform,
                   file=sys.stderr)
        sys.exit(-1)
    if FLAGS.runtime not in ['cpu', 'gpu', 'dsp', 'cpu+gpu']:
        six.print_("runtime %s is not supported." % FLAGS.runtime,
                   file=sys.stderr)
        sys.exit(-1)

    option = cvt.ConverterOption()
    if FLAGS.graph_optimize_options:
        option.transformer_option = FLAGS.graph_optimize_options.split(',')
    option.winograd = FLAGS.winograd
    option.quantize = FLAGS.quantize
    option.quantize_range_file = FLAGS.quantize_range_file
    option.change_concat_ranges = FLAGS.change_concat_ranges
    option.cl_mem_type = FLAGS.cl_mem_type

    # Node names are comma separated. Shapes and ranges are colon
    # separated, one entry per node.
    input_node_names = FLAGS.input_node.split(',')
    input_node_shapes = FLAGS.input_shape.split(':')
    if FLAGS.input_range:
        input_node_ranges = FLAGS.input_range.split(':')
    else:
        input_node_ranges = []
    if len(input_node_names) != len(input_node_shapes):
        raise Exception('input node count and shape count do not match.')
    for i, (name, shape) in enumerate(zip(input_node_names,
                                          input_node_shapes)):
        input_node = cvt.NodeInfo()
        input_node.name = name
        input_node.shape = parse_int_array_from_str(shape)
        # Ranges are optional and may cover only the leading inputs.
        if i < len(input_node_ranges):
            input_node.range = parse_float_array_from_str(input_node_ranges[i])
        option.add_input_node(input_node)

    output_node_names = FLAGS.output_node.split(',')
    output_node_shapes = FLAGS.output_shape.split(':')
    if len(output_node_names) != len(output_node_shapes):
        raise Exception('output node count and shape count do not match.')
    for name, shape in zip(output_node_names, output_node_shapes):
        output_node = cvt.NodeInfo()
        output_node.name = name
        output_node.shape = parse_int_array_from_str(shape)
        option.add_output_node(output_node)

    if FLAGS.check_node != '':
        check_node_names = FLAGS.check_node.split(',')
        check_node_shapes = FLAGS.check_shape.split(':')
        if len(check_node_names) != len(check_node_shapes):
            raise Exception('check node count and shape count do not match.')
        for name, shape in zip(check_node_names, check_node_shapes):
            check_node = cvt.NodeInfo()
            check_node.name = name
            check_node.shape = parse_int_array_from_str(shape)
            option.add_check_node(check_node)

    option.build()

    print("Transform model to one that can better run on device")
    if FLAGS.runtime == 'dsp' and not option.quantize:
        # Non-quantized DSP models go through the dedicated TensorFlow DSP
        # converter only; the MACE transformer and memory optimizer are
        # not run on this path.
        mace_check(FLAGS.platform == 'tensorflow',
                   'DSP only supports tensorflow')
        from mace.python.tools.converter_tool import tf_dsp_converter
        converter = tf_dsp_converter.TensorflowDspConverter(
            option, FLAGS.model_file)
        output_graph_def = converter.run()
    else:
        if FLAGS.platform == 'tensorflow':
            from mace.python.tools.converter_tool import tensorflow_converter
            converter = tensorflow_converter.TensorflowConverter(
                option, FLAGS.model_file)
        elif FLAGS.platform == 'caffe':
            from mace.python.tools.converter_tool import caffe_converter
            converter = caffe_converter.CaffeConverter(option,
                                                       FLAGS.model_file,
                                                       FLAGS.weight_file)
        else:
            six.print_("Mace do not support platorm %s yet." % FLAGS.platform,
                       file=sys.stderr)
            # sys.exit rather than the site-module ``exit`` helper, which
            # is not guaranteed to exist (e.g. under ``python -S``).
            sys.exit(1)

        output_graph_def = converter.run()

        if FLAGS.runtime == 'cpu+gpu':
            # Keep an untouched copy so the CPU pass starts from the same
            # converted graph as the GPU pass.
            cpu_graph_def = copy.deepcopy(output_graph_def)

            option.device = cvt.DeviceType.GPU.value
            option.data_type = parse_data_type(FLAGS.data_type,
                                               cvt.DeviceType.GPU.value)
            mace_gpu_transformer = transformer.Transformer(
                option, output_graph_def)
            output_graph_def, _ = mace_gpu_transformer.run()
            six.print_("start optimize gpu memory.")
            memory_optimizer.optimize_gpu_memory(output_graph_def)
            six.print_("GPU memory optimization done.")

            # The same option object is retargeted to the CPU, and
            # disable_transpose_filters() is called before the CPU pass.
            option.device = cvt.DeviceType.CPU.value
            option.data_type = parse_data_type(FLAGS.data_type,
                                               cvt.DeviceType.CPU.value)
            option.disable_transpose_filters()
            mace_cpu_transformer = transformer.Transformer(
                option, cpu_graph_def)
            cpu_graph_def, _ = mace_cpu_transformer.run()
            print("start optimize cpu memory.")
            memory_optimizer.optimize_cpu_memory(cpu_graph_def)
            print("CPU memory optimization done.")

            print("Merge cpu and gpu ops together")
            output_graph_def.op.extend(cpu_graph_def.op)
            output_graph_def.mem_arena.mem_block.extend(
                cpu_graph_def.mem_arena.mem_block)
            output_graph_arg_names = {arg.name for arg in output_graph_def.arg}
            for arg in cpu_graph_def.arg:
                if arg.name not in output_graph_arg_names:
                    # A repeated message field's extend() expects an iterable
                    # of messages, so the single argument is wrapped in a list.
                    output_graph_def.arg.extend([arg])
            print("Merge done")
        else:
            option.device = device_type_map[FLAGS.runtime]
            option.data_type = parse_data_type(FLAGS.data_type, option.device)
            mace_transformer = transformer.Transformer(option,
                                                       output_graph_def)
            output_graph_def, quantize_activation_info = mace_transformer.run()

            if FLAGS.runtime == 'dsp':
                from mace.python.tools.converter_tool import hexagon_converter
                converter = hexagon_converter.HexagonConverter(
                    option, output_graph_def, quantize_activation_info)
                output_graph_def = converter.run()

            print("start optimize memory.")
            if FLAGS.runtime == 'gpu':
                memory_optimizer.optimize_gpu_memory(output_graph_def)
            elif FLAGS.runtime == 'cpu':
                memory_optimizer.optimize_cpu_memory(output_graph_def)
            elif FLAGS.runtime == 'dsp':
                # DSP graphs are not passed through the memory optimizer.
                pass
            else:
                mace_check(False, "runtime only support [gpu|cpu|dsp]")

            print("Memory optimization done.")

    model_saver.save_model(output_graph_def, model_checksum, weight_checksum,
                           FLAGS.template_dir, FLAGS.obfuscate,
                           FLAGS.model_tag, FLAGS.output_dir, FLAGS.runtime,
                           FLAGS.embed_model_data, FLAGS.winograd,
                           FLAGS.data_type, FLAGS.model_graph_format)
Example #2
0
def main(unused_args):
    """Convert a TensorFlow, Caffe or ONNX model into a MACE graph and save it.

    Every setting is read from the module-level ``FLAGS``. The function:

    * validates the model file (and, for Caffe, the weight file) and the
      optional expected checksums;
    * validates the platform and runtime;
    * fills a ``cvt.ConverterOption`` with the target device/data type and
      the input/output/check nodes (4-D NCHW node shapes are rewritten to
      NHWC);
    * converts the model and runs the MACE transformer;
    * for HEXAGON/HTA devices runs the hexagon converter, and for runtime
      ``'apu'`` runs the APU converter;
    * tries to write an HTML visualization;
    * passes the resulting graph to ``model_saver.save_model``.

    Per-node data types and data formats may be given either once per node
    or as a single value shared by every node.

    Args:
        unused_args: Ignored.

    Raises:
        Exception: If the number of names and shapes given for the input,
            output or check nodes differ, or if runtime ``'apu'`` is used
            with a non-TensorFlow model.

    The process exits with status -1 when an input file is missing, a
    checksum does not match, or the platform/runtime is unsupported.
    """
    if not os.path.isfile(FLAGS.model_file):
        six.print_("Input graph file '" + FLAGS.model_file +
                   "' does not exist!",
                   file=sys.stderr)
        sys.exit(-1)

    # An empty expected checksum disables verification.
    model_checksum = file_checksum(FLAGS.model_file)
    if FLAGS.model_checksum != "" and FLAGS.model_checksum != model_checksum:
        six.print_("Model checksum mismatch: %s != %s" %
                   (model_checksum, FLAGS.model_checksum),
                   file=sys.stderr)
        sys.exit(-1)

    # Only Caffe models are read together with a separate weight file.
    weight_checksum = None
    if FLAGS.platform == 'caffe':
        if not os.path.isfile(FLAGS.weight_file):
            six.print_("Input weight file '" + FLAGS.weight_file +
                       "' does not exist!",
                       file=sys.stderr)
            sys.exit(-1)

        weight_checksum = file_checksum(FLAGS.weight_file)
        if FLAGS.weight_checksum != "" and \
                FLAGS.weight_checksum != weight_checksum:
            six.print_("Weight checksum mismatch: %s != %s" %
                       (weight_checksum, FLAGS.weight_checksum),
                       file=sys.stderr)
            sys.exit(-1)

    if FLAGS.platform not in ['tensorflow', 'caffe', 'onnx']:
        six.print_("platform %s is not supported." % FLAGS.platform,
                   file=sys.stderr)
        sys.exit(-1)
    if FLAGS.runtime not in ['cpu', 'gpu', 'dsp', 'hta', 'apu', 'cpu+gpu']:
        six.print_("runtime %s is not supported." % FLAGS.runtime,
                   file=sys.stderr)
        sys.exit(-1)

    option = cvt.ConverterOption()
    if FLAGS.graph_optimize_options:
        option.transformer_option = FLAGS.graph_optimize_options.split(',')
    option.winograd = FLAGS.winograd
    option.quantize = FLAGS.quantize
    option.quantize_large_weights = FLAGS.quantize_large_weights
    option.quantize_range_file = FLAGS.quantize_range_file
    option.change_concat_ranges = FLAGS.change_concat_ranges
    option.cl_mem_type = FLAGS.cl_mem_type
    option.device = device_type_map[FLAGS.runtime]
    option.data_type = parse_data_type(FLAGS.data_type, option.device)

    def _pick(values, index):
        """Return values[index], or the only entry when one value is given."""
        return values[0] if len(values) == 1 else values[index]

    # Node names, data types and data formats are comma separated. Shapes
    # and ranges are colon separated.
    input_node_names = FLAGS.input_node.split(',')
    input_data_types = FLAGS.input_data_types.split(',')
    input_node_shapes = FLAGS.input_shape.split(':')
    input_node_formats = FLAGS.input_data_formats.split(",")
    if FLAGS.input_range:
        input_node_ranges = FLAGS.input_range.split(':')
    else:
        input_node_ranges = []
    if len(input_node_names) != len(input_node_shapes):
        raise Exception('input node count and shape count do not match.')
    for i, (name, shape) in enumerate(zip(input_node_names,
                                          input_node_shapes)):
        input_node = cvt.NodeInfo()
        input_node.name = name
        input_node.data_type = data_type_map[_pick(input_data_types, i)]
        input_node.data_format = data_format_map[_pick(input_node_formats, i)]
        input_node.shape = parse_int_array_from_str(shape)
        # 4-D NCHW shapes are permuted to NHWC order and the node is then
        # recorded as NHWC.
        if input_node.data_format == cvt.DataFormat.NCHW and \
                len(input_node.shape) == 4:
            input_node.shape = transpose_shape(input_node.shape, [0, 2, 3, 1])
            input_node.data_format = cvt.DataFormat.NHWC
        # Ranges are optional and may cover only the leading inputs.
        if i < len(input_node_ranges):
            input_node.range = parse_float_array_from_str(input_node_ranges[i])
        option.add_input_node(input_node)

    output_node_names = FLAGS.output_node.split(',')
    output_data_types = FLAGS.output_data_types.split(',')
    output_node_shapes = FLAGS.output_shape.split(':')
    output_node_formats = FLAGS.output_data_formats.split(",")
    if len(output_node_names) != len(output_node_shapes):
        raise Exception('output node count and shape count do not match.')
    for i, (name, shape) in enumerate(zip(output_node_names,
                                          output_node_shapes)):
        output_node = cvt.NodeInfo()
        output_node.name = name
        output_node.data_type = data_type_map[_pick(output_data_types, i)]
        output_node.data_format = data_format_map[
            _pick(output_node_formats, i)]
        output_node.shape = parse_int_array_from_str(shape)
        # Same NCHW -> NHWC rewrite as for the inputs.
        if output_node.data_format == cvt.DataFormat.NCHW and \
                len(output_node.shape) == 4:
            output_node.shape = transpose_shape(output_node.shape,
                                                [0, 2, 3, 1])
            output_node.data_format = cvt.DataFormat.NHWC
        option.add_output_node(output_node)

    if FLAGS.check_node != '':
        check_node_names = FLAGS.check_node.split(',')
        check_node_shapes = FLAGS.check_shape.split(':')
        if len(check_node_names) != len(check_node_shapes):
            raise Exception('check node count and shape count do not match.')
        for name, shape in zip(check_node_names, check_node_shapes):
            check_node = cvt.NodeInfo()
            check_node.name = name
            check_node.shape = parse_int_array_from_str(shape)
            option.add_check_node(check_node)
    else:
        # Without explicit check nodes, the output nodes are checked.
        option.check_nodes = option.output_nodes

    option.build()

    print("Transform model to one that can better run on device")
    if FLAGS.platform == 'tensorflow':
        from mace.python.tools.converter_tool import tensorflow_converter
        converter = tensorflow_converter.TensorflowConverter(
            option, FLAGS.model_file)
    elif FLAGS.platform == 'caffe':
        from mace.python.tools.converter_tool import caffe_converter
        converter = caffe_converter.CaffeConverter(option, FLAGS.model_file,
                                                   FLAGS.weight_file)
    elif FLAGS.platform == 'onnx':
        from mace.python.tools.converter_tool import onnx_converter
        converter = onnx_converter.OnnxConverter(option, FLAGS.model_file)
    else:
        six.print_("Mace do not support platorm %s yet." % FLAGS.platform,
                   file=sys.stderr)
        # sys.exit rather than the site-module ``exit`` helper, which is not
        # guaranteed to exist (e.g. under ``python -S``).
        sys.exit(1)

    output_graph_def = converter.run()
    mace_transformer = transformer.Transformer(option, output_graph_def)
    output_graph_def, quantize_activation_info = mace_transformer.run()

    if option.device in [
            cvt.DeviceType.HEXAGON.value, cvt.DeviceType.HTA.value
    ]:
        from mace.python.tools.converter_tool import hexagon_converter
        converter = hexagon_converter.HexagonConverter(
            option, output_graph_def, quantize_activation_info)
        output_graph_def = converter.run()
    elif FLAGS.runtime == 'apu':
        if FLAGS.platform != 'tensorflow':
            raise Exception('apu only support model from tensorflow')
        from mace.python.tools.converter_tool import apu_converter
        converter = apu_converter.ApuConverter(option, output_graph_def,
                                               quantize_activation_info)
        output_graph_def = converter.run()

    # Visualization is best effort: a failure is reported but does not stop
    # the model from being saved. Catch Exception, not everything, so that
    # KeyboardInterrupt and SystemExit still propagate.
    try:
        visualizer = visualize_model.ModelVisualizer(FLAGS.model_tag,
                                                     output_graph_def)
        visualizer.save_html()
    except Exception:
        print("Failed to visualize model:", sys.exc_info()[0])

    model_saver.save_model(option, output_graph_def, model_checksum,
                           weight_checksum, FLAGS.template_dir,
                           FLAGS.obfuscate, FLAGS.model_tag, FLAGS.output_dir,
                           FLAGS.embed_model_data, FLAGS.winograd,
                           FLAGS.model_graph_format)
Example #3
0
def main(unused_args):
    """Convert a TensorFlow or Caffe model into a MACE graph and save it.

    Every setting is read from the module-level ``FLAGS``. The function:

    * validates the model file (and, for Caffe, the weight file) and the
      optional expected checksums;
    * validates the platform and runtime;
    * fills a ``cvt.ConverterOption`` with the input/output/check nodes;
    * converts the model;
    * except for non-quantized DSP models, runs the MACE transformer (and
      the hexagon converter for runtime ``'dsp'``);
    * passes the resulting graph to ``model_saver.save_model``.

    A single value in ``FLAGS.input_data_formats`` /
    ``FLAGS.output_data_formats`` is applied to every input/output node.

    Args:
        unused_args: Ignored.

    Raises:
        Exception: If the number of names and shapes given for the input,
            output or check nodes differ.

    The process exits with status -1 when an input file is missing, a
    checksum does not match, or the platform/runtime is unsupported.
    """
    if not os.path.isfile(FLAGS.model_file):
        six.print_("Input graph file '" + FLAGS.model_file +
                   "' does not exist!",
                   file=sys.stderr)
        sys.exit(-1)

    # An empty expected checksum disables verification.
    model_checksum = file_checksum(FLAGS.model_file)
    if FLAGS.model_checksum != "" and FLAGS.model_checksum != model_checksum:
        six.print_("Model checksum mismatch: %s != %s" %
                   (model_checksum, FLAGS.model_checksum),
                   file=sys.stderr)
        sys.exit(-1)

    # Only Caffe models are read together with a separate weight file.
    weight_checksum = None
    if FLAGS.platform == 'caffe':
        if not os.path.isfile(FLAGS.weight_file):
            six.print_("Input weight file '" + FLAGS.weight_file +
                       "' does not exist!",
                       file=sys.stderr)
            sys.exit(-1)

        weight_checksum = file_checksum(FLAGS.weight_file)
        if FLAGS.weight_checksum != "" and \
                FLAGS.weight_checksum != weight_checksum:
            six.print_("Weight checksum mismatch: %s != %s" %
                       (weight_checksum, FLAGS.weight_checksum),
                       file=sys.stderr)
            sys.exit(-1)

    if FLAGS.platform not in ['tensorflow', 'caffe']:
        six.print_("platform %s is not supported." % FLAGS.platform,
                   file=sys.stderr)
        sys.exit(-1)
    if FLAGS.runtime not in ['cpu', 'gpu', 'dsp', 'cpu+gpu']:
        six.print_("runtime %s is not supported." % FLAGS.runtime,
                   file=sys.stderr)
        sys.exit(-1)

    option = cvt.ConverterOption()
    if FLAGS.graph_optimize_options:
        option.transformer_option = FLAGS.graph_optimize_options.split(',')
    option.winograd = FLAGS.winograd
    option.quantize = FLAGS.quantize
    option.quantize_range_file = FLAGS.quantize_range_file
    option.change_concat_ranges = FLAGS.change_concat_ranges
    option.cl_mem_type = FLAGS.cl_mem_type

    def _pick(values, index):
        """Return values[index], or the only entry when one value is given."""
        return values[0] if len(values) == 1 else values[index]

    # Node names and data formats are comma separated. Shapes and ranges
    # are colon separated.
    input_node_names = FLAGS.input_node.split(',')
    input_node_shapes = FLAGS.input_shape.split(':')
    input_node_formats = FLAGS.input_data_formats.split(",")
    if FLAGS.input_range:
        input_node_ranges = FLAGS.input_range.split(':')
    else:
        input_node_ranges = []
    if len(input_node_names) != len(input_node_shapes):
        raise Exception('input node count and shape count do not match.')
    for i, (name, shape) in enumerate(zip(input_node_names,
                                          input_node_shapes)):
        input_node = cvt.NodeInfo()
        input_node.name = name
        input_node.data_format = data_format_map[_pick(input_node_formats, i)]
        input_node.shape = parse_int_array_from_str(shape)
        # Ranges are optional and may cover only the leading inputs.
        if i < len(input_node_ranges):
            input_node.range = parse_float_array_from_str(input_node_ranges[i])
        option.add_input_node(input_node)

    output_node_names = FLAGS.output_node.split(',')
    output_node_shapes = FLAGS.output_shape.split(':')
    output_node_formats = FLAGS.output_data_formats.split(",")
    if len(output_node_names) != len(output_node_shapes):
        raise Exception('output node count and shape count do not match.')
    for i, (name, shape) in enumerate(zip(output_node_names,
                                          output_node_shapes)):
        output_node = cvt.NodeInfo()
        output_node.name = name
        output_node.data_format = data_format_map[
            _pick(output_node_formats, i)]
        output_node.shape = parse_int_array_from_str(shape)
        option.add_output_node(output_node)

    if FLAGS.check_node != '':
        check_node_names = FLAGS.check_node.split(',')
        check_node_shapes = FLAGS.check_shape.split(':')
        if len(check_node_names) != len(check_node_shapes):
            raise Exception('check node count and shape count do not match.')
        for name, shape in zip(check_node_names, check_node_shapes):
            check_node = cvt.NodeInfo()
            check_node.name = name
            check_node.shape = parse_int_array_from_str(shape)
            option.add_check_node(check_node)

    option.build()

    print("Transform model to one that can better run on device")
    if FLAGS.runtime == 'dsp' and not option.quantize:
        # Non-quantized DSP models go through the dedicated TensorFlow DSP
        # converter only; the MACE transformer is not run on this path.
        mace_check(FLAGS.platform == 'tensorflow',
                   'DSP only supports tensorflow')
        from mace.python.tools.converter_tool import tf_dsp_converter
        converter = tf_dsp_converter.TensorflowDspConverter(
            option, FLAGS.model_file)
        output_graph_def = converter.run()
    else:
        if FLAGS.platform == 'tensorflow':
            from mace.python.tools.converter_tool import tensorflow_converter
            converter = tensorflow_converter.TensorflowConverter(
                option, FLAGS.model_file)
        elif FLAGS.platform == 'caffe':
            from mace.python.tools.converter_tool import caffe_converter
            converter = caffe_converter.CaffeConverter(option,
                                                       FLAGS.model_file,
                                                       FLAGS.weight_file)
        else:
            six.print_("Mace do not support platorm %s yet." % FLAGS.platform,
                       file=sys.stderr)
            # sys.exit rather than the site-module ``exit`` helper, which
            # is not guaranteed to exist (e.g. under ``python -S``).
            sys.exit(1)

        output_graph_def = converter.run()

        option.device = device_type_map[FLAGS.runtime]
        option.data_type = parse_data_type(FLAGS.data_type, option.device)
        mace_transformer = transformer.Transformer(option, output_graph_def)
        output_graph_def, quantize_activation_info = mace_transformer.run()

        if FLAGS.runtime == 'dsp':
            from mace.python.tools.converter_tool import hexagon_converter
            converter = hexagon_converter.HexagonConverter(
                option, output_graph_def, quantize_activation_info)
            output_graph_def = converter.run()

    model_saver.save_model(option, output_graph_def, model_checksum,
                           weight_checksum, FLAGS.template_dir,
                           FLAGS.obfuscate, FLAGS.model_tag, FLAGS.output_dir,
                           FLAGS.embed_model_data, FLAGS.winograd,
                           FLAGS.model_graph_format)