def testConvertBytes(self):
  """Checks C source/header generation with and without optional features.

  The first conversion uses a 16-column line width and disables the license
  (no include guard or include path given); the second uses an 80-column
  width with an include guard, an include path and the license enabled.
  """
  # assertIn/assertNotIn replace assertTrue(str.find(...)): find() returns -1
  # on a miss, and -1 is truthy, so the old form could never fail when the
  # expected text was absent.

  # --- Without license, include guard or include path. ---
  source, header = util.convert_bytes_to_c_source(
      b"\x00\x01\x02\x23", "foo", 16, use_tensorflow_license=False)
  self.assertIn("const unsigned char foo[] DATA_ALIGN_ATTRIBUTE = {", source)
  # Expect the 16-column width to wrap the four bytes into two rows of two
  # values. Rows are compared stripped so the check does not depend on the
  # exact leading padding.
  source_rows = [row.strip() for row in source.splitlines()]
  self.assertIn("0x00, 0x01,", source_rows)
  self.assertIn("0x02, 0x23,", source_rows)
  self.assertIn("const int foo_len = 4;", source)
  self.assertNotIn("/* Copyright", source)
  # No include_path was given, so no include directive should be emitted.
  self.assertNotIn("#include ", source)
  self.assertIn("extern const unsigned char foo[];", header)
  self.assertIn("extern const int foo_len;", header)
  self.assertNotIn("/* Copyright", header)

  # --- With license, include guard and include path. ---
  source, header = util.convert_bytes_to_c_source(
      b"\xff\xfe\xfd\xfc",
      "bar",
      80,
      include_guard="MY_GUARD",
      include_path="my/guard.h",
      use_tensorflow_license=True)
  self.assertIn("const unsigned char bar[] DATA_ALIGN_ATTRIBUTE = {", source)
  self.assertIn("0xff, 0xfe, 0xfd, 0xfc,", source)
  self.assertIn("/* Copyright", source)
  self.assertIn("#include \"my/guard.h\"", source)
  self.assertIn("#ifndef MY_GUARD", header)
  self.assertIn("#define MY_GUARD", header)
  self.assertIn("/* Copyright", header)
def write_model(model_object, output_tflite_file, include_path, array_name):
  """Serializes `model_object` and writes it to `output_tflite_file`.

  The output format is selected by the file name's extension:
    * `.cc`: C source text produced by `convert_bytes_to_c_source` with the
      TensorFlow license prefix, written in text mode. Only the source part
      of the generated (source, header) pair is written; the header is
      discarded.
    * `.tflite`: the raw serialized model, written in binary mode.

  The model is serialized before the extension is checked.

  Args:
    model_object: A tflite model as a python object.
    output_tflite_file: Full path name to the output file.
    include_path: Path to the model header file; used only for `.cc` output.
    array_name: Name of the C array; used only for `.cc` output.

  Raises:
    ValueError: If `output_tflite_file` ends in neither `.cc` nor `.tflite`.
  """
  serialized = _convert_model_from_object_to_bytearray(model_object)

  if output_tflite_file.endswith('.cc'):
    c_source, _ = convert_bytes_to_c_source(
        data=serialized,
        array_name=array_name,
        include_path=include_path,
        use_tensorflow_license=True)
    open_mode, payload = 'w', c_source
  elif output_tflite_file.endswith('.tflite'):
    open_mode, payload = 'wb', serialized
  else:
    raise ValueError('File format not supported')

  with open(output_tflite_file, open_mode) as out_file:
    out_file.write(payload)
def main(_):
  """Converts FLAGS.input_tflite_file into a C source file and a C header.

  Reads the input file as raw bytes and converts it with
  `util.convert_bytes_to_c_source`. The array name, line width, include
  guard, include path and license switch come from FLAGS. The resulting
  source text is written to FLAGS.output_source_file and the header text to
  FLAGS.output_header_file.
  """
  with open(FLAGS.input_tflite_file, "rb") as tflite_file:
    model_bytes = tflite_file.read()

  conversion_options = dict(
      array_name=FLAGS.array_variable_name,
      max_line_width=FLAGS.line_width,
      include_guard=FLAGS.include_guard,
      include_path=FLAGS.include_path,
      use_tensorflow_license=FLAGS.use_tensorflow_license)
  c_source, c_header = util.convert_bytes_to_c_source(
      data=model_bytes, **conversion_options)

  with open(FLAGS.output_source_file, "w") as source_out:
    source_out.write(c_source)
  with open(FLAGS.output_header_file, "w") as header_out:
    header_out.write(c_header)
def _build_arg_parser():
  """Returns the argparse parser for the file-to-C-source command line."""
  parser = argparse.ArgumentParser(
      description=("Command line tool to convert a TensorFlow Lite file into "
                   "a C source file and a C header file."))
  parser.add_argument(
      "--input_tflite_file",
      type=str,
      help="Full filepath of the input TensorFlow Lite file.",
      required=True)
  parser.add_argument(
      "--output_source_file",
      type=str,
      help="Full filepath of the output C source file.",
      required=True)
  parser.add_argument(
      "--output_header_file",
      type=str,
      help="Full filepath of the output C header file.",
      required=True)
  parser.add_argument(
      "--array_variable_name",
      type=str,
      help="Name to use for the C data array variable.",
      required=True)
  parser.add_argument(
      "--line_width",
      type=int,
      help="Width to use for formatting.",
      default=80)
  parser.add_argument(
      "--include_guard",
      type=str,
      help="Name to use for the C header include guard.",
      default=None)
  parser.add_argument(
      "--include_path",
      type=str,
      help="Optional path to include in generated source file.",
      default=None)
  parser.add_argument(
      "--use_tensorflow_license",
      dest="use_tensorflow_license",
      help=
      "Whether to prefix the generated files with the TF Apache2 license.",
      action="store_true")
  parser.set_defaults(use_tensorflow_license=False)
  return parser


def run_main(_):
  """Main in convert_file_to_c_source.py.

  Parses the command line from sys.argv and reads the input TensorFlow Lite
  file as bytes. It converts the bytes with `util.convert_bytes_to_c_source`
  and writes the generated C source and header text to the requested paths.

  Args:
    _: Unused positional argument; it is ignored.
  """
  parser = _build_arg_parser()
  # parse_known_args tolerates arguments this parser does not define instead
  # of exiting with an error; those leftovers are not used.
  flags, _unused_args = parser.parse_known_args(args=sys.argv[1:])

  with open(flags.input_tflite_file, "rb") as input_handle:
    input_data = input_handle.read()

  source, header = util.convert_bytes_to_c_source(
      data=input_data,
      array_name=flags.array_variable_name,
      max_line_width=flags.line_width,
      include_guard=flags.include_guard,
      include_path=flags.include_path,
      use_tensorflow_license=flags.use_tensorflow_license)

  with open(flags.output_source_file, "w") as source_handle:
    source_handle.write(source)
  with open(flags.output_header_file, "w") as header_handle:
    header_handle.write(header)