def get_dequantize_opcode_idx(model):
  """Returns the indices of DEQUANTIZE operator codes in the model.

  Args:
    model: A TFLite model object exposing an `operatorCodes` sequence
      (e.g. a mutable `schema_fb.ModelT`).

  Returns:
    A list of indices into `model.operatorCodes` whose builtin code is
    `schema_fb.BuiltinOperator.DEQUANTIZE`.
  """
  # NOTE: the previous docstring said "quantize", but the scan below matches
  # DEQUANTIZE — the local naming now reflects that.
  dequant_opcode_idxs = []
  for idx, opcode in enumerate(model.operatorCodes):
    builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
    if builtin_code == schema_fb.BuiltinOperator.DEQUANTIZE:
      dequant_opcode_idxs.append(idx)
  return dequant_opcode_idxs
def get_ops_list(model_data):
  """Return a set of ops in the tflite model data.

  Args:
    model_data: A serialized TFLite model (flatbuffer bytes).

  Returns:
    A set of op names: custom op names decoded from the opcode's
    `CustomCode`, otherwise the builtin op name from `visualize`.
  """
  model = schema_fb.Model.GetRootAsModel(model_data, 0)
  ops = set()
  for sg_idx in range(model.SubgraphsLength()):
    graph = model.Subgraphs(sg_idx)
    for operator_idx in range(graph.OperatorsLength()):
      operator = graph.Operators(operator_idx)
      code = model.OperatorCodes(operator.OpcodeIndex())
      builtin = schema_util.get_builtin_code_from_operator_code(code)
      if builtin == schema_fb.BuiltinOperator.CUSTOM:
        ops.add(code.CustomCode().decode("utf-8"))
      else:
        ops.add(visualize.BuiltinCodeToName(builtin))
  return ops
def _modify_model_output_type(model, inference_output_type=dtypes.float32):
  """Modify model output type in place.

  Rewrites the first subgraph so its outputs use `inference_output_type`
  instead of float32, by converting or removing the trailing DEQUANTIZE ops.

  Args:
    model: A mutable TFLite model object (e.g. `schema_fb.ModelT`).
    inference_output_type: Desired tf output dtype. `tf.float32` is a no-op;
      `tf.uint8` rewrites the dequant ops into int8->uint8 quantize ops;
      other types in `_MAP_QUANT_TO_IO_TYPES` drop the dequant ops entirely.

  Raises:
    ValueError: If the model outputs are not dequantized, or if
      `inference_output_type` is unsupported.
  """
  if inference_output_type == dtypes.float32:
    return

  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all dequantize operators.
  dequant_opcode_idxs = get_dequantize_opcode_idx(model)
  if operators and not dequant_opcode_idxs:
    for output in subgraph.outputs:
      output_type = _convert_tflite_enum_type_to_tf_type(tensors[output].type)
      if output_type == dtypes.float32:
        raise ValueError("Model output is not dequantized.")
    # None of the outputs have float32, then they must be int16, int8, or bool.
    return

  # Validate that the model output is dequantized.
  output_dequant_ops = []
  for op in operators:
    # Find operators that dequantize model output.
    if (op.opcodeIndex in dequant_opcode_idxs and
        op.outputs[0] in subgraph.outputs):
      # If found, validate that the operator's output type is float.
      quant_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
      if float_type != dtypes.float32:
        if float_type == inference_output_type:
          continue
        else:
          # Fix: use the module helper `_get_tf_type_name` (matching
          # `_modify_model_input_type`); the bare `get_tf_type_name` name
          # does not exist and would raise NameError on this error path.
          raise ValueError(
              "Initial model output type must be tf.float32. Expected type for "
              "tensor with name '{}' is tf.float32, instead type is {}".format(
                  float_tensor.name, _get_tf_type_name(float_type)))
      # If found, validate that the operator input is quantized and compatible
      # with the final model output type.
      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
        raise ValueError(
            "Initial model output is not dequantized. Expected type for "
            "tensor with name '{}' should be in {}, instead type is {}".format(
                quant_tensor.name,
                tuple(_get_tf_type_name(t)
                      for t in _MAP_QUANT_TO_IO_TYPES.keys()),
                _get_tf_type_name(quant_type)))
      else:
        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
        if inference_output_type not in inference_io_types:
          raise ValueError(
              "Unsupported `inference_output_type` value. Expected to be in "
              "{}, instead got {}.".format(
                  tuple(_get_tf_type_name(t) for t in inference_io_types),
                  _get_tf_type_name(inference_output_type)))
      output_dequant_ops.append(op)

  if len(subgraph.outputs) != len(output_dequant_ops):
    logging.warning(
        "For model outputs containing unsupported operations which cannot be "
        "quantized, the `inference_output_type` attribute will default to the "
        "original type.")

  # Modify model output type.
  if inference_output_type == dtypes.uint8:
    # Find a quantize operator.
    quant_opcode_idx = -1
    for idx, opcode in enumerate(model.operatorCodes):
      builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
      if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
        quant_opcode_idx = idx
        break
    # Create a quantize operator, if none exist.
    if quant_opcode_idx == -1:
      quant_op = schema_fb.OperatorCodeT()
      quant_op.builtinCode = schema_fb.BuiltinOperator.QUANTIZE
      quant_op.deprecatedBuiltinCode = schema_fb.BuiltinOperator.QUANTIZE
      model.operatorCodes.append(quant_op)
      quant_opcode_idx = len(model.operatorCodes) - 1
    # Change dequant op (int8 to float) to quant op (int8 to uint8).
    for op in output_dequant_ops:
      op.opcodeIndex = quant_opcode_idx
      int8_quantization = tensors[op.inputs[0]].quantization
      uint8_quantization = schema_fb.QuantizationParametersT()
      uint8_quantization.scale = [int8_quantization.scale[0]]
      # Shift the int8 zero point by 128 to re-center for uint8.
      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
      tensors[op.outputs[0]].quantization = uint8_quantization
      tensors[op.outputs[0]].type = schema_fb.TensorType.UINT8
  elif inference_output_type in _MAP_QUANT_TO_IO_TYPES:
    # Remove the outputs and the dequant operator.
    remove_tensors_idxs = set()
    for op in output_dequant_ops:
      # Rewire the subgraph output (numpy boolean-mask assignment) to the
      # dequant op's quantized input tensor.
      subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
      if model.signatureDefs:
        signature_def = model.signatureDefs[0]
        for i in range(len(signature_def.outputs)):
          if signature_def.outputs[i].tensorIndex == op.outputs[0]:
            signature_def.outputs[i].tensorIndex = op.inputs[0]
      remove_tensors_idxs.add(op.outputs[0])
      operators.remove(op)
    # Remove tensors marked for deletion.
    _remove_tensors_from_model(model, remove_tensors_idxs)
  else:
    raise ValueError("Unsupported `inference_output_type` value {}.".format(
        _get_tf_type_name(inference_output_type)))
def _modify_model_input_type(model, inference_input_type=dtypes.float32):
  """Modify model input type in place.

  Rewrites the first subgraph so its inputs use `inference_input_type`
  instead of float32, by converting or removing the leading QUANTIZE ops.

  Args:
    model: A mutable TFLite model object (presumably `schema_fb.ModelT` —
      TODO confirm against callers).
    inference_input_type: Desired tf input dtype. `tf.float32` is a no-op;
      `tf.uint8` rewrites the quant ops to take uint8 input; other types in
      `_MAP_QUANT_TO_IO_TYPES` drop the quant ops entirely.

  Raises:
    ValueError: If the model input is not quantized, or if
      `inference_input_type` is unsupported.
  """
  if inference_input_type == dtypes.float32:
    return

  # NOTE(review): only the first subgraph is modified.
  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all quantize operators
  quant_opcode_idxs = []
  for idx, opcode in enumerate(model.operatorCodes):
    builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
    if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
      quant_opcode_idxs.append(idx)
  if operators and not quant_opcode_idxs:
    raise ValueError("Model input is not quantized.")

  # Validate that the model input is quantized
  input_quant_ops = []
  for op in operators:
    # Find operators that quantize model input
    if op.opcodeIndex in quant_opcode_idxs and op.inputs[0] in subgraph.inputs:
      float_tensor, quant_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
      # If found, validate that the operator's input type is float
      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
      if float_type != dtypes.float32:
        raise ValueError(
            "Initial model input type must be tf.float32. Expected type for "
            "tensor with name '{}' is tf.float32, instead type is {}".format(
                float_tensor.name, _get_tf_type_name(float_type)))
      # If found, validate that the operator output is quantized and compatible
      # with the final model input type
      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
        raise ValueError(
            "Initial model input is not quantized. Expected type for "
            "tensor with name '{}' should be in {}, instead type is {}".format(
                quant_tensor.name,
                tuple(_get_tf_type_name(t)
                      for t in _MAP_QUANT_TO_IO_TYPES.keys()),
                _get_tf_type_name(quant_type)))
      else:
        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
        if inference_input_type not in inference_io_types:
          raise ValueError(
              "Unsupported `inference_input_type` value. Expected to be in "
              "{}, instead got {}.".format(
                  tuple(_get_tf_type_name(t) for t in inference_io_types),
                  _get_tf_type_name(inference_input_type)))
      input_quant_ops.append(op)

  # Inputs feeding unsupported (non-quantizable) ops keep their original type.
  if len(subgraph.inputs) != len(input_quant_ops):
    logging.warning(
        "For model inputs containing unsupported operations which cannot be "
        "quantized, the `inference_input_type` attribute will default to the "
        "original type."
    )

  # Modify model input type
  if inference_input_type == dtypes.uint8:
    # Change quant op (float to int8) to quant op (uint8 to int8)
    for op in input_quant_ops:
      int8_quantization = tensors[op.outputs[0]].quantization
      uint8_quantization = schema_fb.QuantizationParametersT()
      uint8_quantization.scale = [int8_quantization.scale[0]]
      # Zero point shifted by 128 to re-center int8 params for uint8 input.
      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
      tensors[op.inputs[0]].quantization = uint8_quantization
      tensors[op.inputs[0]].type = schema_fb.TensorType.UINT8
  elif inference_input_type in _MAP_QUANT_TO_IO_TYPES:
    # Remove the inputs and the quant operator
    remove_tensors_idxs = set()
    for op in input_quant_ops:
      # Rewire the subgraph input (numpy boolean-mask assignment) to the
      # quant op's output tensor, then drop the float input tensor and the op.
      subgraph.inputs[subgraph.inputs == op.inputs[0]] = op.outputs[0]
      remove_tensors_idxs.add(op.inputs[0])
      operators.remove(op)
    # Remove tensors marked for deletion.
    _remove_tensors_from_model(model, remove_tensors_idxs)
  else:
    raise ValueError(
        "Unsupported `inference_input_type` value {}.".format(
            _get_tf_type_name(inference_input_type)))