Example #1
def _modify_model_output_type(model, inference_output_type=dtypes.float32):
    """Modify model output type."""

    if inference_output_type == dtypes.float32:
        return

    subgraph = model.subgraphs[0]
    tensors = subgraph.tensors
    operators = subgraph.operators

    # Find all dequantize operators
    dequant_opcode_idxs = get_dequantize_opcode_idx(model)
    if operators and not dequant_opcode_idxs:
        for output in subgraph.outputs:
            output_type = _convert_tflite_enum_type_to_tf_type(
                tensors[output].type)
            if output_type == dtypes.float32:
                raise ValueError("Model output is not dequantized.")
        # None of the outputs are float32, so they must already be int16, int8, or bool
        return

    # Validate that the model output is dequantized
    output_dequant_ops = []
    for op in operators:
        # Find operators that dequantize model output
        if (op.opcodeIndex in dequant_opcode_idxs
                and op.outputs[0] in subgraph.outputs):
            # If found, validate that the operator's output type is float
            quant_tensor, float_tensor = tensors[op.inputs[0]], tensors[
                op.outputs[0]]
            float_type = _convert_tflite_enum_type_to_tf_type(
                float_tensor.type)
            if float_type != dtypes.float32:
                if float_type == inference_output_type:
                    continue
                else:
                    raise ValueError(
                        "Initial model output type must be tf.float32. Expected type for "
                        "tensor with name '{}' is tf.float32, instead type is {}"
                        .format(float_tensor.name,
                                get_tf_type_name(float_type)))
            # If found, validate that the operator input is quantized and compatible
            # with the final model output type
            quant_type = _convert_tflite_enum_type_to_tf_type(
                quant_tensor.type)
            if quant_type not in _MAP_QUANT_TO_IO_TYPES:
                raise ValueError(
                    "Initial model output is not dequantized. Expected type for "
                    "tensor with name '{}' should be in {}, instead type is {}"
                    .format(
                        quant_tensor.name,
                        tuple(
                            get_tf_type_name(t)
                            for t in _MAP_QUANT_TO_IO_TYPES.keys()),
                        get_tf_type_name(quant_type)))
            else:
                inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
                if inference_output_type not in inference_io_types:
                    raise ValueError(
                        "Unsupported `inference_output_type` value. Expected to be in "
                        "{}, instead got {}.".format(
                            tuple(
                                get_tf_type_name(t)
                                for t in inference_io_types),
                            get_tf_type_name(inference_output_type)))
            output_dequant_ops.append(op)

    if len(subgraph.outputs) != len(output_dequant_ops):
        logging.warning(
            "For model outputs containing unsupported operations which cannot be "
            "quantized, the `inference_output_type` attribute will default to the "
            "original type.")

    # Modify model output type
    if inference_output_type == dtypes.uint8:
        # Find a quantize operator
        quant_opcode_idx = -1
        for idx, opcode in enumerate(model.operatorCodes):
            builtin_code = schema_util.get_builtin_code_from_operator_code(
                opcode)
            if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
                quant_opcode_idx = idx
                break
        # Create a quantize operator, if none exist
        if quant_opcode_idx == -1:
            quant_op = schema_fb.OperatorCodeT()
            quant_op.builtinCode = schema_fb.BuiltinOperator.QUANTIZE
            quant_op.deprecatedBuiltinCode = schema_fb.BuiltinOperator.QUANTIZE
            model.operatorCodes.append(quant_op)
            quant_opcode_idx = len(model.operatorCodes) - 1
        # Change dequant op (int8 to float) to quant op (int8 to uint8)
        for op in output_dequant_ops:
            op.opcodeIndex = quant_opcode_idx
            int8_quantization = tensors[op.inputs[0]].quantization
            uint8_quantization = schema_fb.QuantizationParametersT()
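            # Same scale as the int8 tensor; shift the zero point by 128 so the
            # int8 range [-128, 127] maps onto the uint8 range [0, 255].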
            uint8_quantization.scale = [int8_quantization.scale[0]]
            uint8_quantization.zeroPoint = [
                int8_quantization.zeroPoint[0] + 128
            ]
            tensors[op.outputs[0]].quantization = uint8_quantization
            tensors[op.outputs[0]].type = schema_fb.TensorType.UINT8
    elif inference_output_type in _MAP_QUANT_TO_IO_TYPES:
        # Remove the outputs and the dequant operator
        remove_tensors_idxs = set()
        for op in output_dequant_ops:
            subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
            if model.signatureDefs:
                signature_def = model.signatureDefs[0]
                for i in range(len(signature_def.outputs)):
                    if signature_def.outputs[i].tensorIndex == op.outputs[0]:
                        signature_def.outputs[i].tensorIndex = op.inputs[0]
            remove_tensors_idxs.add(op.outputs[0])
            operators.remove(op)
        # Remove tensors marked for deletion.
        _remove_tensors_from_model(model, remove_tensors_idxs)
    else:
        raise ValueError(
            "Unsupported `inference_output_type` value {}.".format(
                get_tf_type_name(inference_output_type)))
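This helper only rewrites an in-memory flatbuffer object (a schema_fb.ModelT); something else has to load the .tflite file into that object form and serialize it back afterwards. Below is a minimal driver sketch. The flatbuffer_utils.read_model/write_model helpers and the import paths are assumptions about the surrounding TensorFlow Lite tooling rather than something Example #1 itself establishes, the file paths are placeholders, and _modify_model_output_type is assumed to be in scope (defined as above or imported from the converter utilities).

from tensorflow.lite.tools import flatbuffer_utils
from tensorflow.python.framework import dtypes

# Load a fully quantized model whose outputs still end in DEQUANTIZE ops.
model = flatbuffer_utils.read_model("/tmp/quantized_model.tflite")  # placeholder path

# Drop the trailing DEQUANTIZE ops so the model emits int8 outputs directly.
_modify_model_output_type(model, inference_output_type=dtypes.int8)

# Serialize the modified flatbuffer object back to disk.
flatbuffer_utils.write_model(model, "/tmp/quantized_model_int8_out.tflite")  # placeholder path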
Example #2
def _modify_model_output_type(model, inference_output_type=dtypes.float32):
    """Modify model output type."""

    if inference_output_type == dtypes.float32:
        return

    if inference_output_type not in _TFLITE_MODEL_INPUT_OUTPUT_TYPES:
        raise ValueError(
            "Unsupported `inference_output_type` value. Expected to be in {}, "
            "instead got {}.".format(
                tuple(
                    _get_tf_type_name(t)
                    for t in _TFLITE_MODEL_INPUT_OUTPUT_TYPES),
                _get_tf_type_name(inference_output_type)))

    subgraph = model.subgraphs[0]
    tensors = subgraph.tensors
    operators = subgraph.operators

    # Find all dequantize operators
    dequant_opcode_idxs = []
    for idx, opcode in enumerate(model.operatorCodes):
        if opcode.builtinCode == schema_fb.BuiltinOperator.DEQUANTIZE:
            dequant_opcode_idxs.append(idx)
    if not dequant_opcode_idxs:
        raise ValueError("Model output is not dequantized.")

    # Ensure that the model output is dequantized
    output_dequant_ops = []
    for op in operators:
        # Check if the operator dequantizes an output
        if (op.opcodeIndex in dequant_opcode_idxs
                and op.outputs[0] in subgraph.outputs):
            # If found, validate the operator input/output tensor types
            int_tensor, float_tensor = tensors[op.inputs[0]], tensors[
                op.outputs[0]]
            if float_tensor.type != schema_fb.TensorType.FLOAT32:
                raise ValueError(
                    "Model output type must be tf.float32. Expected type for tensor "
                    "with name '{}' is tf.float32, instead type is {}".format(
                        float_tensor.name,
                        _get_tf_type_name(
                            _convert_tflite_enum_type_to_tf_type(
                                float_tensor.type))))
            if int_tensor.type != schema_fb.TensorType.INT8:
                raise ValueError(
                    "Model output is not dequantized. Expected type for tensor "
                    "with name '{}' is tf.int8, instead type is {}".format(
                        int_tensor.name,
                        _get_tf_type_name(
                            _convert_tflite_enum_type_to_tf_type(
                                int_tensor.type))))
            output_dequant_ops.append(op)

    if len(subgraph.outputs) != len(output_dequant_ops):
        raise ValueError("Model output is not dequantized.")

    # Modify model output type
    if inference_output_type == dtypes.uint8:
        # Find a quantize operator
        quant_opcode_idx = -1
        for idx, opcode in enumerate(model.operatorCodes):
            if opcode.builtinCode == schema_fb.BuiltinOperator.QUANTIZE:
                quant_opcode_idx = idx
                break
        # Create a quantize operator, if none exist
        if quant_opcode_idx == -1:
            quant_op = schema_fb.OperatorCodeT()
            quant_op.builtinCode = schema_fb.BuiltinOperator.QUANTIZE
            quant_op.deprecatedBuiltinCode = schema_fb.BuiltinOperator.QUANTIZE
            model.operatorCodes.append(quant_op)
            quant_opcode_idx = len(model.operatorCodes) - 1
        # Change dequant op (int8 to float) to quant op (int8 to uint8)
        for op in output_dequant_ops:
            op.opcodeIndex = quant_opcode_idx
            int8_quantization = tensors[op.inputs[0]].quantization
            uint8_quantization = schema_fb.QuantizationParametersT()
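            # Same scale as the int8 tensor; shift the zero point by 128 so the
            # int8 range [-128, 127] maps onto the uint8 range [0, 255].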
            uint8_quantization.scale = [int8_quantization.scale[0]]
            uint8_quantization.zeroPoint = [
                int8_quantization.zeroPoint[0] + 128
            ]
            tensors[op.outputs[0]].quantization = uint8_quantization
            tensors[op.outputs[0]].type = schema_fb.TensorType.UINT8
    elif inference_output_type == dtypes.int8:
        # Remove the outputs and the dequant operator
        remove_tensors_idxs = set()
        for op in output_dequant_ops:
            subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
            remove_tensors_idxs.add(op.outputs[0])
            operators.remove(op)
        # Remove tensors marked for deletion.
        _remove_tensors_from_model(model, remove_tensors_idxs)
    else:
        raise ValueError(
            "Unsupported `inference_output_type` value. Expected to be in {}, "
            "instead got {}.".format(
                tuple(
                    _get_tf_type_name(t)
                    for t in _TFLITE_MODEL_INPUT_OUTPUT_TYPES),
                _get_tf_type_name(inference_output_type)))
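In both examples the uint8 path keeps the int8 scale and only shifts the zero point by 128, which is the affine re-centring that maps the int8 range [-128, 127] onto the uint8 range [0, 255] without changing the real values the tensor represents. A small standalone check of that identity follows; it uses plain NumPy, is independent of the TFLite schema, and the quantization parameter values are made up for illustration.

import numpy as np

def dequantize(q, scale, zero_point):
    # real_value = scale * (quantized_value - zero_point)
    return scale * (np.asarray(q, dtype=np.int64) - zero_point)

scale, int8_zero_point = 0.05, -3          # made-up quantization parameters
uint8_zero_point = int8_zero_point + 128   # the shift applied in the examples above

int8_values = np.array([-128, -3, 0, 127], dtype=np.int64)
uint8_values = int8_values + 128           # what the rewritten QUANTIZE op produces

# Dequantizing either representation yields the same real values.
assert np.allclose(dequantize(int8_values, scale, int8_zero_point),
                   dequantize(uint8_values, scale, uint8_zero_point))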