def main(_):
  """Application run loop: randomize the weights of a tflite model.

  Parses the command line, reads the model from --input_tflite_file,
  randomizes its weights with --random_seed and writes the result to
  --output_tflite_file.
  """
  parser = argparse.ArgumentParser(
      description='Randomize weights in a tflite file.')
  parser.add_argument(
      '--input_tflite_file',
      type=str,
      required=True,
      help='Full path name to the input tflite file.')
  parser.add_argument(
      '--output_tflite_file',
      type=str,
      required=True,
      help='Full path name to the output randomized tflite file.')
  # Parse the seed as an int so an explicit `--random_seed 0` has the same
  # type (and value) as the default 0. With type=str it arrived as '0'.
  # NOTE(review): if randomize_weights seeds Python's `random` module, str
  # and int seeds produce different sequences.
  parser.add_argument(
      '--random_seed',
      type=int,
      required=False,
      default=0,
      help='Input to the random number generator. The default value is 0.')
  args = parser.parse_args()

  # Read the model
  model = flatbuffer_utils.read_model(args.input_tflite_file)
  # randomize_weights is used as an in-place operation on `model` (its
  # return value is not relied on), matching the FLAGS-based tool.
  flatbuffer_utils.randomize_weights(model, args.random_seed)
  # Write the model
  flatbuffer_utils.write_model(model, args.output_tflite_file)
def testXxdOutputToBytes(self):
  """Checks that xxd_output_to_bytes recovers the bytes of an `xxd -i` dump.

  Serializes a mock model, writes it to a temporary .tflite file, dumps that
  file with `xxd -i` into a .cc file and asserts that parsing the .cc file
  yields the original bytearray.
  """
  # 1. SETUP
  # Define the initial model
  initial_model = test_utils.build_mock_model()
  initial_bytes = flatbuffer_utils.convert_object_to_bytearray(
      initial_model)
  # Define temporary files
  tmp_dir = self.get_temp_dir()
  model_filename = os.path.join(tmp_dir, 'model.tflite')

  # 2. Write model to temporary file (will be used as input for xxd)
  flatbuffer_utils.write_model(initial_model, model_filename)

  # 3. DUMP WITH xxd
  input_cc_file = os.path.join(tmp_dir, 'model.cc')
  # Run xxd without a shell: the paths are passed as argv entries, so spaces
  # or shell metacharacters in the temp dir cannot break the command.
  # check_call raises if xxd is missing or fails, instead of the exit status
  # being silently ignored.
  with open(input_cc_file, 'w') as cc_file:
    subprocess.check_call(['xxd', '-i', model_filename], stdout=cc_file)

  # 4. VALIDATE
  final_bytes = flatbuffer_utils.xxd_output_to_bytes(input_cc_file)
  # Validate that the initial and final bytearray are the same
  self.assertEqual(initial_bytes, final_bytes)
def tflite_graph_rewrite(tflite_model_path,
                         saved_model_dir,
                         custom_op_registerers=None):
  """Rewrite TFLite graph to make inputs/outputs tensor names consistent.

  TF users do not have good control over output tensor names from
  get_concrete_function(). To maintain backward compatibility, the tensor
  names in the TFLite graph need to be meaningful and properly set. This
  function looks up the meaningful names from the SavedModel signature
  metadata and rewrites them into the TFLite graph.

  Only the first subgraph of the TFLite model and the "serving_default"
  signature of the "serve" MetaGraph are considered.

  Args:
    tflite_model_path: The path to the exported TFLite graph, which will be
      overwritten after the rewrite.
    saved_model_dir: Directory that stores the SavedModel that was used for
      TFLite conversion.
    custom_op_registerers: List with custom op registerers, forwarded to the
      interpreter used to read the model's input/output details.

  Raises:
    KeyError: If a TFLite input/output tensor name has no matching entry in
      the SavedModel's "serving_default" signature.
  """
  # Map SavedModel tensor name -> signature input/output name.
  meta_graph_def = saved_model_utils.get_meta_graph_def(saved_model_dir,
                                                         "serve")
  serving_signature = meta_graph_def.signature_def["serving_default"]
  tensor_name_to_signature_name = {}
  for key, value in serving_signature.inputs.items():
    tensor_name_to_signature_name[value.name] = key
  for key, value in serving_signature.outputs.items():
    tensor_name_to_signature_name[value.name] = key

  def _signature_name(tensor_name):
    """Returns the signature name for `tensor_name` or raises KeyError."""
    try:
      return tensor_name_to_signature_name[tensor_name]
    except KeyError:
      raise KeyError(
          "TFLite tensor %r has no matching input/output in the SavedModel "
          "'serving_default' signature." % (tensor_name,))

  # TFLite input/output tensor names, listed in input/output index order.
  with tf.io.gfile.GFile(tflite_model_path, "rb") as f:
    interpreter = interpreter_wrapper.InterpreterWithCustomOps(
        model_content=f.read(), custom_op_registerers=custom_op_registerers)
  tflite_input_names = [
      detail["name"] for detail in interpreter.get_input_details()
  ]
  tflite_output_names = [
      detail["name"] for detail in interpreter.get_output_details()
  ]

  # Rewrite TFLite graph inputs/outputs name.
  mutable_fb = flatbuffer_utils.read_model_with_mutable_tensors(
      tflite_model_path)
  subgraph = mutable_fb.subgraphs[0]
  for input_idx, input_tensor_name in enumerate(tflite_input_names):
    subgraph.tensors[subgraph.inputs[input_idx]].name = _signature_name(
        input_tensor_name)
  for output_idx, output_tensor_name in enumerate(tflite_output_names):
    subgraph.tensors[subgraph.outputs[output_idx]].name = _signature_name(
        output_tensor_name)
  flatbuffer_utils.write_model(mutable_fb, tflite_model_path)
def testWriteReadModel(self):
  """Round-trips a mock model object through write_model/read_model.

  Writes the mock model to a temporary .tflite file, reads it back and
  checks the description, the main subgraph (name, inputs, outputs,
  operator opcode indices, tensors) and buffer 1 are unchanged. Sequences
  are compared as whole lists, so missing or extra elements in the re-read
  model fail the assertion instead of raising IndexError or going unnoticed.
  """
  # 1. SETUP
  # Define the initial model
  initial_model = test_utils.build_mock_model_python_object()
  # Define temporary files
  tmp_dir = self.get_temp_dir()
  model_filename = os.path.join(tmp_dir, 'model.tflite')

  # 2. INVOKE
  # Invoke the write_model and read_model functions
  flatbuffer_utils.write_model(initial_model, model_filename)
  final_model = flatbuffer_utils.read_model(model_filename)

  # 3. VALIDATE
  # Validate the description
  self.assertEqual(initial_model.description, final_model.description)

  # Validate the main subgraph's name, inputs, outputs, operators and tensors
  initial_subgraph = initial_model.subgraphs[0]
  final_subgraph = final_model.subgraphs[0]
  self.assertEqual(initial_subgraph.name, final_subgraph.name)
  # list() normalizes array-like containers so they compare element-wise
  # (and by length).
  self.assertEqual(
      list(initial_subgraph.inputs), list(final_subgraph.inputs))
  self.assertEqual(
      list(initial_subgraph.outputs), list(final_subgraph.outputs))
  self.assertEqual(
      [op.opcodeIndex for op in initial_subgraph.operators],
      [op.opcodeIndex for op in final_subgraph.operators])

  initial_tensors = initial_subgraph.tensors
  final_tensors = final_subgraph.tensors
  self.assertEqual(len(initial_tensors), len(final_tensors))
  for initial_tensor, final_tensor in zip(initial_tensors, final_tensors):
    self.assertEqual(initial_tensor.name, final_tensor.name)
    self.assertEqual(initial_tensor.type, final_tensor.type)
    self.assertEqual(initial_tensor.buffer, final_tensor.buffer)
    self.assertEqual(list(initial_tensor.shape), list(final_tensor.shape))

  # Validate the first valid buffer (index 0 is always None)
  initial_buffer = initial_model.buffers[1].data
  final_buffer = final_model.buffers[1].data
  self.assertEqual(initial_buffer.tolist(), final_buffer.tolist())
def main(_):
  """Application run loop: rebuild a tflite file from an xxd C++ dump.

  Parses the command line, reads the model object from the C++ source file
  given by --input_cc_file and writes it to --output_tflite_file.
  """
  # The previous description ("from to binary file") was garbled, and the
  # output help was copied from the strip tool ("stripped output").
  parser = argparse.ArgumentParser(
      description='Reverses xxd dump from a C++ source file to a binary '
      'tflite file.')
  parser.add_argument(
      '--input_cc_file',
      type=str,
      required=True,
      help='Full path name to the input cc file.')
  parser.add_argument(
      '--output_tflite_file',
      type=str,
      required=True,
      help='Full path name to the output tflite file.')
  args = parser.parse_args()

  # Read the model from xxd output C++ source file
  model = flatbuffer_utils.xxd_output_to_object(args.input_cc_file)
  # Write the model
  flatbuffer_utils.write_model(model, args.output_tflite_file)
def main(_):
  """Application run loop for the argparse-driven strip-strings tool.

  Reads the model named by --input_tflite_file, passes it to
  flatbuffer_utils.strip_strings and writes it to --output_tflite_file.
  """
  parser = argparse.ArgumentParser(
      description='Strips all nonessential strings from a tflite file.')
  # Both flags share the same shape: required string paths.
  required_path_flags = (
      ('--input_tflite_file', 'Full path name to the input tflite file.'),
      ('--output_tflite_file',
       'Full path name to the stripped output tflite file.'),
  )
  for flag_name, help_text in required_path_flags:
    parser.add_argument(flag_name, type=str, required=True, help=help_text)
  parsed = parser.parse_args()

  tflite_model = flatbuffer_utils.read_model(parsed.input_tflite_file)
  # The return value is ignored and `tflite_model` is written afterwards, so
  # this relies on strip_strings modifying the model object in place.
  flatbuffer_utils.strip_strings(tflite_model)
  flatbuffer_utils.write_model(tflite_model, parsed.output_tflite_file)
def main(_):
  """Strips strings from FLAGS.input_tflite_file into FLAGS.output_tflite_file.

  The model object is passed to flatbuffer_utils.strip_strings and then
  written out; the return value of strip_strings is not used.
  """
  source_path = FLAGS.input_tflite_file
  tflite_model = flatbuffer_utils.read_model(source_path)
  flatbuffer_utils.strip_strings(tflite_model)
  flatbuffer_utils.write_model(tflite_model, FLAGS.output_tflite_file)
def main(_):
  """Randomizes weights of FLAGS.input_tflite_file and saves the result.

  The model object is passed to flatbuffer_utils.randomize_weights together
  with FLAGS.random_seed and then written to FLAGS.output_tflite_file; the
  return value of randomize_weights is not used.
  """
  source_path = FLAGS.input_tflite_file
  tflite_model = flatbuffer_utils.read_model(source_path)
  flatbuffer_utils.randomize_weights(tflite_model, FLAGS.random_seed)
  flatbuffer_utils.write_model(tflite_model, FLAGS.output_tflite_file)
def main(_):
  """Writes the model parsed from FLAGS.input_cc_file to a tflite file.

  FLAGS.input_cc_file is the C++ source produced by xxd; the parsed model
  object is written to FLAGS.output_tflite_file.
  """
  cc_source_path = FLAGS.input_cc_file
  tflite_model = flatbuffer_utils.xxd_output_to_object(cc_source_path)
  flatbuffer_utils.write_model(tflite_model, FLAGS.output_tflite_file)