コード例 #1
0
def main(_):
    """Parse command-line options and convert a model's interface types.

    Reads ``--input_file``, ``--output_file``, ``--input_type`` and
    ``--output_type``, maps the two type names through
    ``mmi_constants.STR_TO_TFLITE_TYPES`` and passes the paths and mapped
    types to ``mmi_lib.modify_model_interface``. A confirmation message is
    printed once that call returns.

    Args:
        _: Unused. ``parse_args()`` is called without arguments, so argparse
            takes the options from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description=
        "Modify a quantized model's interface from float to integer.")

    # The two path options are mandatory strings.
    for flag, file_help in (
            ('--input_file', 'Full path name to the input tflite file.'),
            ('--output_file', 'Full path name to the output tflite file.')):
        parser.add_argument(flag, type=str, required=True, help=file_help)

    # The two type options are upper-cased by argparse (type=str.upper)
    # before being checked against the allowed choices.
    for flag, type_help in (
            ('--input_type', 'Modified input integer interface type.'),
            ('--output_type', 'Modified output integer interface type.')):
        parser.add_argument(flag,
                            type=str.upper,
                            choices=mmi_constants.STR_TYPES,
                            default=mmi_constants.DEFAULT_STR_TYPE,
                            help=type_help)

    options = parser.parse_args()

    # Translate the textual type names into the values the library expects.
    tflite_in = mmi_constants.STR_TO_TFLITE_TYPES[options.input_type]
    tflite_out = mmi_constants.STR_TO_TFLITE_TYPES[options.output_type]

    mmi_lib.modify_model_interface(options.input_file, options.output_file,
                                   tflite_in, tflite_out)

    message = ('Successfully modified the model input type from FLOAT to '
               '{input_type} and output type from FLOAT to {output_type}.')
    print(message.format(input_type=options.input_type,
                         output_type=options.output_type))
コード例 #2
0
  def testInt8Interface(self):
    """Check that an int8/int8 conversion changes the model I/O dtypes.

    Writes a freshly built model to a temporary file, runs
    `modify_model_interface` on it with `tf.int8` for both input and output,
    then loads both files and compares the dtype of the first input and the
    first output tensor: float32 before, int8 after.
    """

    def first_io_dtypes(interpreter):
      # Returns the dtypes of the first input and first output tensor.
      in_dtype = interpreter.get_input_details()[0]['dtype']
      out_dtype = interpreter.get_output_details()[0]['dtype']
      return in_dtype, out_dtype

    # --- Arrange: put the source model on disk inside the temp directory.
    work_dir = self.get_temp_dir()
    source_path = os.path.join(work_dir, 'initial_model.tflite')
    target_path = os.path.join(work_dir, 'final_model.tflite')
    model_bytes = build_tflite_model_with_full_integer_quantization()
    with open(source_path, 'wb') as handle:
      handle.write(model_bytes)

    # --- Act: rewrite the interface, producing the target file.
    modify_model_interface_lib.modify_model_interface(source_path, target_path,
                                                      tf.int8, tf.int8)

    # --- Assert: load both models and inspect their interface dtypes.
    before = tf.lite.Interpreter(model_path=source_path)
    before.allocate_tensors()
    after = tf.lite.Interpreter(model_path=target_path)
    after.allocate_tensors()

    before_in, before_out = first_io_dtypes(before)
    after_in, after_out = first_io_dtypes(after)

    self.assertEqual(before_in, np.float32)
    self.assertEqual(before_out, np.float32)
    self.assertEqual(after_in, np.int8)
    self.assertEqual(after_out, np.int8)
コード例 #3
0
def main(_):
    """Convert a model's interface types using values taken from ``FLAGS``.

    Maps ``FLAGS.input_type`` and ``FLAGS.output_type`` through
    ``mmi_constants.STR_TO_TFLITE_TYPES``, calls
    ``mmi_lib.modify_model_interface`` with ``FLAGS.input_file`` and
    ``FLAGS.output_file``, then prints a confirmation message.

    Args:
        _: Unused positional argument.
    """
    # Translate the textual type names into the values the library expects.
    tflite_in = mmi_constants.STR_TO_TFLITE_TYPES[FLAGS.input_type]
    tflite_out = mmi_constants.STR_TO_TFLITE_TYPES[FLAGS.output_type]

    mmi_lib.modify_model_interface(FLAGS.input_file, FLAGS.output_file,
                                   tflite_in, tflite_out)

    message = ('Successfully modified the model input type from FLOAT to '
               '{input_type} and output type from FLOAT to {output_type}.')
    print(message.format(input_type=FLAGS.input_type,
                         output_type=FLAGS.output_type))