Example #1
def get_dequantize_opcode_idx(model):
    """Returns the indices of DEQUANTIZE operator codes in the model."""
    dequant_opcode_idxs = []
    for idx, opcode in enumerate(model.operatorCodes):
        builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
        if builtin_code == schema_fb.BuiltinOperator.DEQUANTIZE:
            dequant_opcode_idxs.append(idx)
    return dequant_opcode_idxs
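
A minimal usage sketch for this helper, assuming a local model file (the path is a placeholder) and the same schema_fb module the example relies on; schema_fb.ModelT.InitFromObj unpacks the flatbuffer into the mutable object API whose operatorCodes list the function iterates:

# Hypothetical usage: load a .tflite flatbuffer and unpack it into the
# mutable object API (ModelT) before querying the operator codes.
with open("model.tflite", "rb") as f:  # placeholder path
    buf = bytearray(f.read())
model = schema_fb.ModelT.InitFromObj(schema_fb.Model.GetRootAsModel(buf, 0))
print(get_dequantize_opcode_idx(model))  # e.g. [2] if operator code 2 is DEQUANTIZE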
Example #2
def get_ops_list(model_data):
    """Return a set of ops in the tflite model data."""
    model = schema_fb.Model.GetRootAsModel(model_data, 0)
    op_set = set()

    for subgraph_idx in range(model.SubgraphsLength()):
        subgraph = model.Subgraphs(subgraph_idx)
        for op_idx in range(subgraph.OperatorsLength()):
            op = subgraph.Operators(op_idx)
            opcode = model.OperatorCodes(op.OpcodeIndex())
            builtin_code = schema_util.get_builtin_code_from_operator_code(
                opcode)
            if builtin_code == schema_fb.BuiltinOperator.CUSTOM:
                opname = opcode.CustomCode().decode("utf-8")
                op_set.add(opname)
            else:
                op_set.add(visualize.BuiltinCodeToName(builtin_code))
    return op_set
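
A brief sketch of calling this helper; get_ops_list works directly on the raw flatbuffer bytes, so no object-API conversion is needed (the model path below is a placeholder):

# Hypothetical usage: list the builtin and custom op names used by a model.
with open("model.tflite", "rb") as f:  # placeholder path
    model_data = f.read()
print(sorted(get_ops_list(model_data)))  # e.g. ['CONV_2D', 'DEQUANTIZE', ...]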
Example #3
def _modify_model_output_type(model, inference_output_type=dtypes.float32):
    """Modify model output type."""

    if inference_output_type == dtypes.float32:
        return

    subgraph = model.subgraphs[0]
    tensors = subgraph.tensors
    operators = subgraph.operators

    # Find all dequantize operators
    dequant_opcode_idxs = get_dequantize_opcode_idx(model)
    if operators and not dequant_opcode_idxs:
        for output in subgraph.outputs:
            output_type = _convert_tflite_enum_type_to_tf_type(
                tensors[output].type)
            if output_type == dtypes.float32:
                raise ValueError("Model output is not dequantized.")
        # If no output tensor is float32, the outputs must already be int16, int8, or bool
        return

    # Validate that the model output is dequantized
    output_dequant_ops = []
    for op in operators:
        # Find operators that dequantize model output
        if (op.opcodeIndex in dequant_opcode_idxs
                and op.outputs[0] in subgraph.outputs):
            # If found, validate that the operator's output type is float
            quant_tensor, float_tensor = tensors[op.inputs[0]], tensors[
                op.outputs[0]]
            float_type = _convert_tflite_enum_type_to_tf_type(
                float_tensor.type)
            if float_type != dtypes.float32:
                if float_type == inference_output_type:
                    continue
                else:
                    raise ValueError(
                        "Initial model output type must be tf.float32. Expected type for "
                        "tensor with name '{}' is tf.float32, instead type is {}"
                        .format(float_tensor.name,
                                get_tf_type_name(float_type)))
            # If found, validate that the operator input is quantized and compatible
            # with the final model output type
            quant_type = _convert_tflite_enum_type_to_tf_type(
                quant_tensor.type)
            if quant_type not in _MAP_QUANT_TO_IO_TYPES:
                raise ValueError(
                    "Initial model output is not dequantized. Expected type for "
                    "tensor with name '{}' should be in {}, instead type is {}"
                    .format(
                        quant_tensor.name,
                        tuple(
                            get_tf_type_name(t)
                            for t in _MAP_QUANT_TO_IO_TYPES.keys()),
                        get_tf_type_name(quant_type)))
            else:
                inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
                if inference_output_type not in inference_io_types:
                    raise ValueError(
                        "Unsupported `inference_output_type` value. Expected to be in "
                        "{}, instead got {}.".format(
                            tuple(
                                get_tf_type_name(t)
                                for t in inference_io_types),
                            get_tf_type_name(inference_output_type)))
            output_dequant_ops.append(op)

    if len(subgraph.outputs) != len(output_dequant_ops):
        logging.warning(
            "For model outputs containing unsupported operations which cannot be "
            "quantized, the `inference_output_type` attribute will default to the "
            "original type.")

    # Modify model output type
    if inference_output_type == dtypes.uint8:
        # Find a quantize operator
        quant_opcode_idx = -1
        for idx, opcode in enumerate(model.operatorCodes):
            builtin_code = schema_util.get_builtin_code_from_operator_code(
                opcode)
            if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
                quant_opcode_idx = idx
                break
        # Create a quantize operator, if none exist
        if quant_opcode_idx == -1:
            quant_op = schema_fb.OperatorCodeT()
            quant_op.builtinCode = schema_fb.BuiltinOperator.QUANTIZE
            quant_op.deprecatedBuiltinCode = schema_fb.BuiltinOperator.QUANTIZE
            model.operatorCodes.append(quant_op)
            quant_opcode_idx = len(model.operatorCodes) - 1
        # Change dequant op (int8 to float) to quant op (int8 to uint8)
        for op in output_dequant_ops:
            op.opcodeIndex = quant_opcode_idx
            int8_quantization = tensors[op.inputs[0]].quantization
            uint8_quantization = schema_fb.QuantizationParametersT()
            uint8_quantization.scale = [int8_quantization.scale[0]]
            uint8_quantization.zeroPoint = [
                int8_quantization.zeroPoint[0] + 128
            ]
            tensors[op.outputs[0]].quantization = uint8_quantization
            tensors[op.outputs[0]].type = schema_fb.TensorType.UINT8
    elif inference_output_type in _MAP_QUANT_TO_IO_TYPES:
        # Remove the outputs and the dequant operator
        remove_tensors_idxs = set()
        for op in output_dequant_ops:
            subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
            if model.signatureDefs:
                signature_def = model.signatureDefs[0]
                for i in range(len(signature_def.outputs)):
                    if signature_def.outputs[i].tensorIndex == op.outputs[0]:
                        signature_def.outputs[i].tensorIndex = op.inputs[0]
            remove_tensors_idxs.add(op.outputs[0])
            operators.remove(op)
        # Remove tensors marked for deletion.
        _remove_tensors_from_model(model, remove_tensors_idxs)
    else:
        raise ValueError(
            "Unsupported `inference_output_type` value {}.".format(
                get_tf_type_name(inference_output_type)))
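
A hedged sketch of driving this function end to end: it mutates an object-API model in place, so the flatbuffer is unpacked first and re-serialized afterwards. The dtypes and flatbuffers imports, the file paths, and the packing boilerplate below are assumptions based on the surrounding module, not part of the example itself:

# Hypothetical usage: produce int8 model outputs instead of float32.
import flatbuffers
from tensorflow.python.framework import dtypes

with open("quantized_model.tflite", "rb") as f:  # placeholder path
    buf = bytearray(f.read())
model = schema_fb.ModelT.InitFromObj(schema_fb.Model.GetRootAsModel(buf, 0))

_modify_model_output_type(model, inference_output_type=dtypes.int8)

# Pack the mutated object back into a serializable flatbuffer.
builder = flatbuffers.Builder(1024)
builder.Finish(model.Pack(builder), file_identifier=b"TFL3")  # TFLite file identifier
with open("model_int8_out.tflite", "wb") as f:  # placeholder path
    f.write(builder.Output())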
Example #4
def _modify_model_input_type(model, inference_input_type=dtypes.float32):
  """Modify model input type."""

  if inference_input_type == dtypes.float32:
    return

  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all quantize operators
  quant_opcode_idxs = []
  for idx, opcode in enumerate(model.operatorCodes):
    builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
    if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
      quant_opcode_idxs.append(idx)
  if operators and not quant_opcode_idxs:
    raise ValueError("Model input is not quantized.")

  # Validate that the model input is quantized
  input_quant_ops = []
  for op in operators:
    # Find operators that quantize model input
    if op.opcodeIndex in quant_opcode_idxs and op.inputs[0] in subgraph.inputs:
      float_tensor, quant_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
      # If found, validate that the operator's input type is float
      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
      if float_type != dtypes.float32:
        raise ValueError(
            "Initial model input type must be tf.float32. Expected type for "
            "tensor with name '{}' is tf.float32, instead type is {}".format(
                float_tensor.name, _get_tf_type_name(float_type)))
      # If found, validate that the operator output is quantized and compatible
      # with the final model input type
      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
        raise ValueError(
            "Initial model input is not quantized. Expected type for "
            "tensor with name '{}' should be in {}, instead type is {}".format(
                quant_tensor.name,
                tuple(_get_tf_type_name(t) for t in
                      _MAP_QUANT_TO_IO_TYPES.keys()),
                _get_tf_type_name(quant_type)))
      else:
        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
        if inference_input_type not in inference_io_types:
          raise ValueError(
              "Unsupported `inference_input_type` value. Expected to be in "
              "{}, instead got {}.".format(
                  tuple(_get_tf_type_name(t) for t in inference_io_types),
                  _get_tf_type_name(inference_input_type)))
      input_quant_ops.append(op)

  if len(subgraph.inputs) != len(input_quant_ops):
    logging.warning(
        "For model inputs containing unsupported operations which cannot be "
        "quantized, the `inference_input_type` attribute will default to the "
        "original type.")

  # Modify model input type
  if inference_input_type == dtypes.uint8:
    # Change quant op (float to int8) to quant op (uint8 to int8)
    for op in input_quant_ops:
      int8_quantization = tensors[op.outputs[0]].quantization
      uint8_quantization = schema_fb.QuantizationParametersT()
      uint8_quantization.scale = [int8_quantization.scale[0]]
      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
      tensors[op.inputs[0]].quantization = uint8_quantization
      tensors[op.inputs[0]].type = schema_fb.TensorType.UINT8
  elif inference_input_type in _MAP_QUANT_TO_IO_TYPES:
    # Remove the inputs and the quant operator
    remove_tensors_idxs = set()
    for op in input_quant_ops:
      subgraph.inputs[subgraph.inputs == op.inputs[0]] = op.outputs[0]
      remove_tensors_idxs.add(op.inputs[0])
      operators.remove(op)
    # Remove tensors marked for deletion.
    _remove_tensors_from_model(model, remove_tensors_idxs)
  else:
    raise ValueError(
        "Unsupported `inference_input_type` value {}.".format(
            _get_tf_type_name(inference_input_type)))
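
As with the output helper in Example #3, a hedged sketch of how this input counterpart might be invoked; the dtypes import, the file path, and the object-API round trip are the same assumptions as before:

# Hypothetical usage: accept uint8 inputs (and outputs) on an int8-quantized model.
from tensorflow.python.framework import dtypes

with open("quantized_model.tflite", "rb") as f:  # placeholder path
    buf = bytearray(f.read())
model = schema_fb.ModelT.InitFromObj(schema_fb.Model.GetRootAsModel(buf, 0))

_modify_model_input_type(model, inference_input_type=dtypes.uint8)
_modify_model_output_type(model, inference_output_type=dtypes.uint8)
# Re-serialize with model.Pack(...) as shown after Example #3.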