# Example #1
# 0
    def testConvertBytes(self):
        """Checks the C source and header produced by convert_bytes_to_c_source.

        Two configurations are exercised:
          * 16-column lines, no include guard/path, no license header.
          * 80-column lines with an include guard, an include path and the
            TensorFlow license header.
        """
        # Presence is checked with assertIn rather than assertTrue(str.find(..)):
        # find() returns -1 (truthy) on a miss and 0 (falsy) on a match at
        # index 0, so a truthiness check cannot detect a missing substring.
        source, header = util.convert_bytes_to_c_source(
            b"\x00\x01\x02\x23", "foo", 16, use_tensorflow_license=False)
        self.assertIn("const unsigned char foo[] DATA_ALIGN_ATTRIBUTE = {",
                      source)
        # With a 16-column limit the expected layout is two bytes per line.
        self.assertIn("""    0x00, 0x01,
    0x02, 0x23,""", source)
        self.assertIn("const int foo_len = 4;", source)
        self.assertNotIn("/* Copyright", source)
        # No include_path was passed, so no #include directive is expected.
        self.assertNotIn("#include ", source)
        self.assertIn("extern const unsigned char foo[];", header)
        self.assertIn("extern const int foo_len;", header)
        self.assertNotIn("/* Copyright", header)

        source, header = util.convert_bytes_to_c_source(
            b"\xff\xfe\xfd\xfc",
            "bar",
            80,
            include_guard="MY_GUARD",
            include_path="my/guard.h",
            use_tensorflow_license=True)
        self.assertIn("const unsigned char bar[] DATA_ALIGN_ATTRIBUTE = {",
                      source)
        self.assertIn("    0xff, 0xfe, 0xfd, 0xfc,", source)
        self.assertIn("/* Copyright", source)
        self.assertIn("#include \"my/guard.h\"", source)
        self.assertIn("#ifndef MY_GUARD", header)
        self.assertIn("#define MY_GUARD", header)
        self.assertIn("/* Copyright", header)
# Example #2
# 0
def write_model(model_object, output_tflite_file, include_path, array_name):
    """Serializes a tflite model object and writes it to `output_tflite_file`.

    The output format is selected by the file extension:
      * ``.tflite``: the serialized model bytes, written in binary mode.
      * ``.cc``: C source text produced by ``convert_bytes_to_c_source`` from
        the serialized bytes (with `array_name`, `include_path` and the
        TensorFlow license enabled), written in text mode.

    Args:
      model_object: A tflite model as a python object.
      output_tflite_file: Full path name to the output file; must end in
        ``.cc`` or ``.tflite``.
      include_path: Path to the model header file (used for ``.cc`` output).
      array_name: Name of the C array (used for ``.cc`` output).

    Raises:
      ValueError: If `output_tflite_file` ends in neither ``.cc`` nor
        ``.tflite``.
    """
    serialized = _convert_model_from_object_to_bytearray(model_object)

    if output_tflite_file.endswith('.tflite'):
        payload, open_mode = serialized, 'wb'
    elif output_tflite_file.endswith('.cc'):
        # Keep only the first element of the result (the C source text); no
        # header file is written here.
        c_source = convert_bytes_to_c_source(
            data=serialized,
            array_name=array_name,
            include_path=include_path,
            use_tensorflow_license=True)[0]
        payload, open_mode = c_source, 'w'
    else:
        raise ValueError('File format not supported')

    with open(output_tflite_file, open_mode) as out_stream:
        out_stream.write(payload)
# Example #3
# 0
def main(_):
    """Converts FLAGS.input_tflite_file into a C source file and a C header.

    The input file is read as raw bytes and handed, together with the naming
    and formatting options taken from FLAGS, to
    util.convert_bytes_to_c_source. The returned source and header strings are
    written to FLAGS.output_source_file and FLAGS.output_header_file.

    Args:
      _: Unused positional argument (presumably argv from the app launcher).
    """
    with open(FLAGS.input_tflite_file, "rb") as tflite_stream:
        model_bytes = tflite_stream.read()

    c_source, c_header = util.convert_bytes_to_c_source(
        data=model_bytes,
        array_name=FLAGS.array_variable_name,
        max_line_width=FLAGS.line_width,
        include_guard=FLAGS.include_guard,
        include_path=FLAGS.include_path,
        use_tensorflow_license=FLAGS.use_tensorflow_license)

    outputs = (
        (FLAGS.output_source_file, c_source),
        (FLAGS.output_header_file, c_header),
    )
    for out_path, text in outputs:
        with open(out_path, "w") as out_stream:
            out_stream.write(text)


def _build_arg_parser():
    """Builds the command-line parser used by run_main.

    Returns:
      An argparse.ArgumentParser defining the input/output paths, the C array
      name, the formatting width, the optional include guard/path and the
      license flag.
    """
    parser = argparse.ArgumentParser(
        description=("Command line tool to convert a TensorFlow Lite model "
                     "file into C source and header files."))

    parser.add_argument(
        "--input_tflite_file",
        type=str,
        help="Full filepath of the input TensorFlow Lite file.",
        required=True)
    parser.add_argument(
        "--output_source_file",
        type=str,
        help="Full filepath of the output C source file.",
        required=True)
    parser.add_argument(
        "--output_header_file",
        type=str,
        help="Full filepath of the output C header file.",
        required=True)
    parser.add_argument(
        "--array_variable_name",
        type=str,
        help="Name to use for the C data array variable.",
        required=True)
    parser.add_argument(
        "--line_width",
        type=int,
        help="Width to use for formatting.",
        default=80)
    parser.add_argument(
        "--include_guard",
        type=str,
        help="Name to use for the C header include guard.",
        default=None)
    parser.add_argument(
        "--include_path",
        type=str,
        help="Optional path to include in generated source file.",
        default=None)
    # action="store_true" already defaults the destination to False, so no
    # separate set_defaults() call is needed.
    parser.add_argument(
        "--use_tensorflow_license",
        dest="use_tensorflow_license",
        help=("Whether to prefix the generated files with the TF Apache2 "
              "license."),
        action="store_true")
    return parser


def run_main(_):
    """Command-line entry point of convert_file_to_c_source.py.

    Parses flags from sys.argv, reads the input TFLite file as bytes, converts
    it with util.convert_bytes_to_c_source and writes the generated C source
    and header to the requested output paths.

    Args:
      _: Unused; arguments are read from sys.argv directly.
    """
    # parse_known_args ignores unrecognized flags instead of exiting with an
    # error, so extra arguments on the command line do not abort the tool.
    flags = _build_arg_parser().parse_known_args(args=sys.argv[1:])[0]

    with open(flags.input_tflite_file, "rb") as input_handle:
        input_data = input_handle.read()

    source, header = util.convert_bytes_to_c_source(
        data=input_data,
        array_name=flags.array_variable_name,
        max_line_width=flags.line_width,
        include_guard=flags.include_guard,
        include_path=flags.include_path,
        use_tensorflow_license=flags.use_tensorflow_license)

    with open(flags.output_source_file, "w") as source_handle:
        source_handle.write(source)

    with open(flags.output_header_file, "w") as header_handle:
        header_handle.write(header)