def main(_):
  """Randomizes the weights of a TFLite model and writes the result.

  Command-line arguments are parsed with argparse from sys.argv (the argv
  passed to this function is ignored):
    --input_tflite_file: path of the model to read (required).
    --output_tflite_file: path the randomized model is written to (required).
    --random_seed: integer seed forwarded to randomize_weights (default 0).
  """
  parser = argparse.ArgumentParser(
      description='Randomize weights in a tflite file.')
  parser.add_argument('--input_tflite_file', type=str, required=True,
                      help='Full path name to the input tflite file.')
  parser.add_argument(
      '--output_tflite_file', type=str, required=True,
      help='Full path name to the output randomized tflite file.')
  # Parse the seed as an int. With type=str an explicit `--random_seed=0`
  # would arrive as the string '0' while the default is the int 0, so the
  # same seed value would reach randomize_weights with different types.
  parser.add_argument(
      '--random_seed', type=int, required=False, default=0,
      help='Input to the random number generator. The default value is 0.')
  args = parser.parse_args()

  # Read the model
  input_model = flatbuffer_utils.read_model(args.input_tflite_file)
  # Invoke the randomize weights function
  randomized_model = flatbuffer_utils.randomize_weights(input_model,
                                                        args.random_seed)
  # NOTE(review): randomize_weights may mutate the model in place and return
  # None (the FLAGS-based entry point ignores its return value); fall back to
  # the mutated input so we never hand None to write_model.
  output_model = input_model if randomized_model is None else randomized_model
  # Write the model
  flatbuffer_utils.write_model(output_model, args.output_tflite_file)
def testWriteReadModel(self):
  """Checks that write_model followed by read_model round-trips a model."""
  # 1. SETUP
  # Define the initial model
  initial_model = test_utils.build_mock_model_python_object()
  # Define temporary files
  tmp_dir = self.get_temp_dir()
  model_filename = os.path.join(tmp_dir, 'model.tflite')

  # 2. INVOKE
  # Invoke the write_model and read_model functions
  flatbuffer_utils.write_model(initial_model, model_filename)
  final_model = flatbuffer_utils.read_model(model_filename)

  # 3. VALIDATE
  # Validate that the initial and final models are the same
  # Validate the description
  self.assertEqual(initial_model.description, final_model.description)

  # Validate the main subgraph's name, inputs, outputs, operators and tensors.
  # Sequences are compared as whole lists so that a length mismatch (extra or
  # missing trailing elements after the round trip) also fails the test.
  initial_subgraph = initial_model.subgraphs[0]
  final_subgraph = final_model.subgraphs[0]
  self.assertEqual(initial_subgraph.name, final_subgraph.name)
  self.assertEqual(list(initial_subgraph.inputs), list(final_subgraph.inputs))
  self.assertEqual(list(initial_subgraph.outputs),
                   list(final_subgraph.outputs))
  self.assertEqual(
      [op.opcodeIndex for op in initial_subgraph.operators],
      [op.opcodeIndex for op in final_subgraph.operators])

  initial_tensors = initial_subgraph.tensors
  final_tensors = final_subgraph.tensors
  self.assertEqual(len(initial_tensors), len(final_tensors))
  for initial_tensor, final_tensor in zip(initial_tensors, final_tensors):
    self.assertEqual(initial_tensor.name, final_tensor.name)
    self.assertEqual(initial_tensor.type, final_tensor.type)
    self.assertEqual(initial_tensor.buffer, final_tensor.buffer)
    self.assertEqual(list(initial_tensor.shape), list(final_tensor.shape))

  # Validate the first valid buffer (index 0 is always None).
  # `buffers[1].data` already is the payload (presumably a numpy array, since
  # the previous code used `.size` on it), so compare its elements directly
  # rather than indexing a second `.data` attribute.
  initial_buffer = initial_model.buffers[1].data
  final_buffer = final_model.buffers[1].data
  self.assertEqual(list(initial_buffer), list(final_buffer))
def main(_):
  """Application run loop.

  Reads --input_tflite_file and --output_tflite_file from the command line
  via argparse (the argv argument is ignored), strips nonessential strings
  from the input model and writes the stripped model to the output path.
  """
  arg_parser = argparse.ArgumentParser(
      description='Strips all nonessential strings from a tflite file.')
  # Both path flags are mandatory string arguments; register them in order.
  path_flags = (
      ('--input_tflite_file', 'Full path name to the input tflite file.'),
      ('--output_tflite_file',
       'Full path name to the stripped output tflite file.'),
  )
  for flag_name, help_text in path_flags:
    arg_parser.add_argument(flag_name, type=str, required=True, help=help_text)
  parsed_args = arg_parser.parse_args()

  tflite_model = flatbuffer_utils.read_model(parsed_args.input_tflite_file)
  # strip_strings' return value is unused; the model is presumably modified
  # in place before being written back out.
  flatbuffer_utils.strip_strings(tflite_model)
  flatbuffer_utils.write_model(tflite_model, parsed_args.output_tflite_file)
def main(_):
  """Strips nonessential strings from the model named by FLAGS.

  Reads FLAGS.input_tflite_file, strips its strings and writes the result to
  FLAGS.output_tflite_file. The argv argument is unused.
  """
  tflite_model = flatbuffer_utils.read_model(FLAGS.input_tflite_file)
  # Return value unused: the model is presumably stripped in place.
  flatbuffer_utils.strip_strings(tflite_model)
  destination_path = FLAGS.output_tflite_file
  flatbuffer_utils.write_model(tflite_model, destination_path)
def main(_):
  """Randomizes model weights using the paths and seed held in FLAGS.

  Reads FLAGS.input_tflite_file, passes the model and FLAGS.random_seed to
  randomize_weights, and writes the model to FLAGS.output_tflite_file. The
  argv argument is unused.
  """
  tflite_model = flatbuffer_utils.read_model(FLAGS.input_tflite_file)
  seed = FLAGS.random_seed
  # Return value discarded: the weights are presumably randomized in place.
  flatbuffer_utils.randomize_weights(tflite_model, seed)
  destination_path = FLAGS.output_tflite_file
  flatbuffer_utils.write_model(tflite_model, destination_path)