def testStripStrings(self):
    """Check that strip_strings clears string fields and keeps structure.

    A mock model is deep-copied and only the copy is stripped. The fields the
    stripper is expected to clear (model description, signature defs,
    subgraph name, tensor names) must be set on the original and None on the
    copy. Structural data (subgraph inputs/outputs, operator opcode indices,
    tensor types/buffers/shapes, and the contents of buffer 1) must be
    identical between the two.
    """
    # 1. SETUP
    # Deep-copy so the untouched original remains available for comparison.
    initial_model = test_utils.build_mock_model()
    final_model = copy.deepcopy(initial_model)

    # 2. INVOKE
    # The return value is not used: strip_strings modifies final_model itself.
    flatbuffer_utils.strip_strings(final_model)

    # 3. VALIDATE
    # Validate that the initial and final models are the same except strings.
    # Model-level fields: present before stripping, cleared afterwards.
    self.assertIsNotNone(initial_model.description)
    self.assertIsNone(final_model.description)
    self.assertIsNotNone(initial_model.signatureDefs)
    self.assertIsNone(final_model.signatureDefs)

    # Main subgraph: the name is cleared, inputs/outputs/operators are kept.
    initial_subgraph = initial_model.subgraphs[0]
    final_subgraph = final_model.subgraphs[0]
    self.assertIsNotNone(initial_subgraph.name)
    self.assertIsNone(final_subgraph.name)
    # Compare whole sequences (converted to lists so that both numpy arrays
    # and plain lists work). A length mismatch then fails the assertion
    # instead of being silently ignored or raising IndexError.
    self.assertEqual(list(initial_subgraph.inputs),
                     list(final_subgraph.inputs))
    self.assertEqual(list(initial_subgraph.outputs),
                     list(final_subgraph.outputs))
    self.assertEqual(
        [op.opcodeIndex for op in initial_subgraph.operators],
        [op.opcodeIndex for op in final_subgraph.operators])

    # Tensors: names are cleared, type/buffer index/shape are preserved.
    initial_tensors = initial_subgraph.tensors
    final_tensors = final_subgraph.tensors
    self.assertEqual(len(initial_tensors), len(final_tensors))
    for index, (initial_tensor, final_tensor) in enumerate(
            zip(initial_tensors, final_tensors)):
        msg = 'tensor %d' % index
        self.assertIsNotNone(initial_tensor.name, msg)
        self.assertIsNone(final_tensor.name, msg)
        self.assertEqual(initial_tensor.type, final_tensor.type, msg)
        self.assertEqual(initial_tensor.buffer, final_tensor.buffer, msg)
        self.assertEqual(list(initial_tensor.shape),
                         list(final_tensor.shape), msg)

    # Validate the first valid buffer (index 0 is always None).
    # `.data` is read once here; the sequences are compared element-wise
    # via list equality, which also checks that the lengths match.
    initial_buffer = initial_model.buffers[1].data
    final_buffer = final_model.buffers[1].data
    self.assertEqual(list(initial_buffer), list(final_buffer))
def main(_):
    """Strip nonessential strings from a .tflite file named on the command line.

    The positional argument is ignored. Both paths are parsed from the process
    command line by argparse, and both flags are required:
      --input_tflite_file   path of the model to read
      --output_tflite_file  path where the stripped model is written
    """
    cli = argparse.ArgumentParser(
        description='Strips all nonessential strings from a tflite file.')
    flag_specs = (
        ('--input_tflite_file', 'Full path name to the input tflite file.'),
        ('--output_tflite_file',
         'Full path name to the stripped output tflite file.'),
    )
    for flag_name, help_text in flag_specs:
        cli.add_argument(flag_name, type=str, required=True, help=help_text)
    options = cli.parse_args()

    tflite_model = flatbuffer_utils.read_model(options.input_tflite_file)
    # The result is not captured: the model object itself is stripped before
    # being written out.
    flatbuffer_utils.strip_strings(tflite_model)
    flatbuffer_utils.write_model(tflite_model, options.output_tflite_file)
def main(_):
    """Read a .tflite model, strip its strings, and write it back out.

    Paths come from the module-level ``FLAGS`` object (``input_tflite_file``
    and ``output_tflite_file``). The positional argument is unused.
    """
    source_path = FLAGS.input_tflite_file
    tflite_model = flatbuffer_utils.read_model(source_path)

    # Stripping is applied to the loaded model object, which is then written.
    flatbuffer_utils.strip_strings(tflite_model)

    destination_path = FLAGS.output_tflite_file
    flatbuffer_utils.write_model(tflite_model, destination_path)