def main(_):
    """Command-line entry point: convert a quantized model's float interface.

    Reads the input/output file paths and the desired integer interface types
    from the command line. It then rewrites the model via
    `mmi_lib.modify_model_interface` and prints a confirmation line.

    Args:
      _: Ignored. Arguments are taken from `sys.argv` by
        `ArgumentParser.parse_args()` rather than from this parameter.
    """
    parser = argparse.ArgumentParser(
        description=
        "Modify a quantized model's interface from float to integer.")

    # Both file-path options are mandatory plain strings.
    path_options = (
        ('--input_file', 'Full path name to the input tflite file.'),
        ('--output_file', 'Full path name to the output tflite file.'),
    )
    for option_name, option_help in path_options:
        parser.add_argument(
            option_name, type=str, required=True, help=option_help)

    # The interface-type options are upper-cased before being validated
    # against the allowed choices, so lower-case input is accepted.
    type_options = (
        ('--input_type', 'Modified input integer interface type.'),
        ('--output_type', 'Modified output integer interface type.'),
    )
    for option_name, option_help in type_options:
        parser.add_argument(
            option_name,
            type=str.upper,
            choices=mmi_constants.STR_TYPES,
            default=mmi_constants.DEFAULT_STR_TYPE,
            help=option_help)

    # NOTE(review): parse_args() with no arguments reads sys.argv[1:], not `_`.
    parsed = parser.parse_args()

    str_to_tflite = mmi_constants.STR_TO_TFLITE_TYPES
    mmi_lib.modify_model_interface(parsed.input_file, parsed.output_file,
                                   str_to_tflite[parsed.input_type],
                                   str_to_tflite[parsed.output_type])

    template = ('Successfully modified the model input type from FLOAT to '
                '{input_type} and output type from FLOAT to {output_type}.')
    print(template.format(input_type=parsed.input_type,
                          output_type=parsed.output_type))
def testInt8Interface(self):
    """An int8/int8 interface rewrite turns float32 model I/O into int8 I/O."""
    # Arrange: write a fully integer-quantized model into a scratch directory.
    scratch_dir = self.get_temp_dir()
    source_path = os.path.join(scratch_dir, 'initial_model.tflite')
    target_path = os.path.join(scratch_dir, 'final_model.tflite')
    serialized_model = build_tflite_model_with_full_integer_quantization()
    with open(source_path, 'wb') as out_stream:
        out_stream.write(serialized_model)

    # Act: request int8 for both the input and the output interface.
    modify_model_interface_lib.modify_model_interface(
        source_path, target_path, tf.int8, tf.int8)

    # Assert: load each model and inspect its first input/output tensor dtype.
    def first_io_dtypes(model_path):
        interpreter = tf.lite.Interpreter(model_path=model_path)
        interpreter.allocate_tensors()
        return (interpreter.get_input_details()[0]['dtype'],
                interpreter.get_output_details()[0]['dtype'])

    before_in, before_out = first_io_dtypes(source_path)
    after_in, after_out = first_io_dtypes(target_path)

    # The original model exposes float32; the rewritten one exposes int8.
    expectations = (
        (before_in, np.float32),
        (before_out, np.float32),
        (after_in, np.int8),
        (after_out, np.int8),
    )
    for observed, expected in expectations:
        self.assertEqual(observed, expected)
def main(_):
    """Rewrite a quantized model's float interface using values from FLAGS.

    The interface-type strings in FLAGS are mapped to TFLite types through
    `mmi_constants.STR_TO_TFLITE_TYPES`. The model is then rewritten by
    `mmi_lib.modify_model_interface`, and a confirmation line is printed.

    Args:
      _: Ignored; all configuration is read from FLAGS.
    """
    type_table = mmi_constants.STR_TO_TFLITE_TYPES
    # Resolve both types up front so an unknown name fails before any I/O.
    tflite_in = type_table[FLAGS.input_type]
    tflite_out = type_table[FLAGS.output_type]

    mmi_lib.modify_model_interface(FLAGS.input_file, FLAGS.output_file,
                                   tflite_in, tflite_out)

    summary = ('Successfully modified the model input type from FLOAT to '
               '{input_type} and output type from FLOAT to {output_type}.'
               ).format(input_type=FLAGS.input_type,
                        output_type=FLAGS.output_type)
    print(summary)