Example #1
import numpy as np
# `_lite` is how the source module aliases the TensorFlow Lite Python API,
# e.g. `from tensorflow.lite.python import lite as _lite`.
from tensorflow.lite.python import lite as _lite


def _evaluate_tflite_model(tflite_model, input_data):
  """Returns evaluation of input data on TFLite model.

  Args:
    tflite_model: Serialized TensorFlow Lite model.
    input_data: List of np.ndarray.

  Returns:
    List of np.ndarray.
  """
  interpreter = _lite.Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()

  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()

  # Feed each input array into its corresponding input tensor.
  for input_tensor, tensor_data in zip(input_details, input_data):
    interpreter.set_tensor(input_tensor["index"], tensor_data)

  interpreter.invoke()
  output_data = [
      interpreter.get_tensor(output_tensor["index"])
      for output_tensor in output_details
  ]
  return output_data
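A minimal usage sketch for the helper above, assuming the imports shown in Example #1 are in scope. The model path and the (1, 16, 16, 3) float32 input shape are hypothetical placeholders, not part of the original example.

# Hypothetical: load a previously converted model from disk.
with open("model.tflite", "rb") as f:
  tflite_model = f.read()

# Placeholder input; shape and dtype must match the model's input tensor.
input_data = [np.random.random_sample((1, 16, 16, 3)).astype(np.float32)]
outputs = _evaluate_tflite_model(tflite_model, input_data)
print([out.shape for out in outputs])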
Example #2
def _generate_random_input_data(tflite_model, seed=None):
  """Generates input data based on the input tensors in the TFLite model.

  Args:
    tflite_model: Serialized TensorFlow Lite model.
    seed: Integer seed for the random generator. (default None)

  Returns:
    List of np.ndarray.
  """
  interpreter = _lite.Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()

  # Check against None explicitly so a seed of 0 is still honored.
  if seed is not None:
    np.random.seed(seed=seed)
  return [
      np.array(np.random.random_sample(input_tensor["shape"]),
               dtype=input_tensor["dtype"]) for input_tensor in input_details
  ]
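The two helpers compose naturally: size random inputs from the model's input tensors, then run them through the interpreter. A minimal sketch, again with a hypothetical model path; the fixed seed makes the generated inputs reproducible.

# Hypothetical path; any serialized TFLite flatbuffer works here.
with open("model.tflite", "rb") as f:
  tflite_model = f.read()

input_data = _generate_random_input_data(tflite_model, seed=42)
output_data = _evaluate_tflite_model(tflite_model, input_data)
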
Example #3
def test_frozen_graph_quant(filename,
                            input_arrays,
                            output_arrays,
                            input_shapes=None,
                            **kwargs):
  """Sanity check to validate post quantize flag alters the graph.

  This test does not check correctness of the converted model. It converts the
  TensorFlow frozen graph to TFLite with and without the post_training_quantized
  flag. It ensures some tensors have different types between the float and
  quantized models in the case of an all TFLite model or mix-and-match model.
  It ensures tensor types do not change in the case of an all Flex model.

  Args:
    filename: Full filepath of file containing frozen GraphDef.
    input_arrays: List of input tensors to freeze graph with.
    output_arrays: List of output tensors to freeze graph with.
    input_shapes: Dict mapping input tensor names to lists of integers
      representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). A shape is
      determined automatically when given as None (e.g., {"foo" : None}).
      (default None)
    **kwargs: Additional arguments to be passed into the converter.

  Raises:
    ValueError: post_training_quantize flag doesn't act as intended.
  """
  # Convert and load the float model.
  converter = _lite.TFLiteConverter.from_frozen_graph(
      filename, input_arrays, output_arrays, input_shapes)
  tflite_model_float = _convert(converter, **kwargs)

  interpreter_float = _lite.Interpreter(model_content=tflite_model_float)
  interpreter_float.allocate_tensors()
  float_tensors = interpreter_float.get_tensor_details()

  # Convert and load the quantized model.
  # Pass the same input_shapes so both conversions see identical graphs.
  converter = _lite.TFLiteConverter.from_frozen_graph(
      filename, input_arrays, output_arrays, input_shapes)
  tflite_model_quant = _convert(
      converter, post_training_quantize=True, **kwargs)

  interpreter_quant = _lite.Interpreter(model_content=tflite_model_quant)
  interpreter_quant.allocate_tensors()
  quant_tensors = interpreter_quant.get_tensor_details()
  quant_tensors_map = {
      tensor_detail["name"]: tensor_detail for tensor_detail in quant_tensors
  }

  # Check if weights are of different types in the float and quantized models.
  num_tensors_float = len(float_tensors)
  num_tensors_same_dtypes = sum(
      float_tensor["dtype"] == quant_tensors_map[float_tensor["name"]]["dtype"]
      for float_tensor in float_tensors)
  has_quant_tensor = num_tensors_float != num_tensors_same_dtypes

  if ("converter_mode" in kwargs and
      kwargs["converter_mode"] == _lite.ConverterMode.TOCO_FLEX_ALL):
    if has_quant_tensor:
      raise ValueError("--post_training_quantize flag unexpectedly altered the "
                       "full Flex mode graph.")
  elif not has_quant_tensor:
    raise ValueError("--post_training_quantize flag was unable to quantize the "
                     "graph as expected in TFLite and mix-and-match mode.")