def _static_range_quantize(saved_model_path: str,
                           signature_keys=None,
                           tags=None,
                           output_directory=None,
                           representative_dataset=None):
  """Quantizes the given SavedModel via static range quantization.

  Args:
    saved_model_path: Path to the saved model. When representative_dataset is
      not provided, this should be a model trained with QAT.
    signature_keys: List of keys identifying SignatureDef containing inputs and
      outputs.
    tags: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze.
    output_directory: The path to save the output SavedModel (must be an empty
      directory).
    representative_dataset: a generator that returns a dictionary in
      {input_name: input_tensor} format or a tuple with signature key and a
      dictionary in {input_name: input_tensor} format that feeds calibration
      data for quantizing the model. This should be provided when the model is
      not a QAT model.

  Returns:
    A SavedModel object with TF quantization applied.

  Raises:
    ValueError: when representative_dataset is not provided for a non-QAT
      model.
  """
  is_qat_saved_model = _is_qat_saved_model(saved_model_path)
  signatures = _get_signatures_from_saved_model(saved_model_path,
                                                signature_keys, tags)

  # Checks if the model is from QAT.
  if representative_dataset is None and not is_qat_saved_model:
    raise ValueError(
        'When `representative_dataset` is not provided, the model should be '
        'trained with quantization-aware training (QAT).')

  if is_qat_saved_model:
    # Handles QAT models.
    graph_def_serialized = (
        quantize_model_wrapper.quantize_qat_model(saved_model_path,
                                                  ','.join(signature_keys),
                                                  ','.join(tags)))
  else:
    # Handles PTQ models by preparing them for calibration.
    graph_def_serialized = (
        quantize_model_wrapper.quantize_ptq_model_pre_calibration(
            saved_model_path, ','.join(signature_keys), ','.join(tags)))

  graph_def = graph_pb2.GraphDef()
  graph_def.ParseFromString(graph_def_serialized)

  float_model_dir = tempfile.mkdtemp()
  v1_builder = builder.SavedModelBuilder(float_model_dir)

  with session.Session(graph=ops.Graph()) as sess:
    # Assigns a unique id to each CustomAggregator op so calibration
    # statistics can be looked up per op later.
    for function_def in graph_def.library.function:
      for node_def in function_def.node_def:
        if node_def.op == 'CustomAggregator':
          node_def.attr['id'].s = uuid.uuid4().hex.encode('ascii')

    importer.import_graph_def(graph_def, name='')
    working_graph = ops.get_default_graph()
    graph_def = working_graph.as_graph_def()

    signatures = _fix_tensor_names(signatures, working_graph)
    if signatures is None:
      raise ValueError(
          "The input SavedModel doesn't contain a valid signature")

    v1_builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING], signature_def_map=signatures)

  v1_builder.save()

  float_model = saved_model_load(float_model_dir)

  # Runs the representative dataset through the instrumented float model to
  # collect calibration statistics.
  for sample in representative_dataset():
    # TODO(b/214311251): Add a test case with multiple signatures.
    if isinstance(sample, tuple):
      if not isinstance(sample[1], dict):
        raise ValueError('You need to provide a dictionary with input '
                         'names and values in the second argument in the '
                         'tuple')
      signature_key = sample[0]
      input_data_map = sample[1]
    elif isinstance(sample, dict):
      if len(signature_keys) > 1:
        raise ValueError('When the model has multiple signatures, you need '
                         'to provide a tuple with signature key and a '
                         'dictionary with input names and values')
      signature_key = signature_keys[0]
      input_data_map = sample
    else:
      raise ValueError('You need to provide either a dictionary with input '
                       'names and values or a tuple with signature key and a '
                       'dictionary with input names and values')

    func = float_model.signatures[signature_key]
    func(**input_data_map)

  # Copies the collected min/max values into the CustomAggregator nodes.
  for function_def in graph_def.library.function:
    for node_def in function_def.node_def:
      if node_def.op == 'CustomAggregator':
        node_id = node_def.attr['id'].s
        try:
          min_val = quantize_model_wrapper.get_min_from_calibrator(node_id)
          max_val = quantize_model_wrapper.get_max_from_calibrator(node_id)
          quantize_model_wrapper.clear_data_from_calibrator(node_id)
          node_def.attr['min'].f = float(min_val)
          node_def.attr['max'].f = float(max_val)
        except ValueError:
          warnings.warn('%s does not have min/max values.' % node_id)

  calibrated_model_dir = tempfile.mkdtemp()
  v1_builder = builder.SavedModelBuilder(calibrated_model_dir)

  with session.Session(graph=ops.Graph()) as sess:
    importer.import_graph_def(graph_def, name='')
    working_graph = ops.get_default_graph()
    graph_def = working_graph.as_graph_def()

    v1_builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING], signature_def_map=signatures)

  v1_builder.save()

  signatures = _get_signatures_from_saved_model(calibrated_model_dir,
                                                signature_keys, tags)

  graph_def_serialized = (
      quantize_model_wrapper.quantize_ptq_model_post_calibration(
          calibrated_model_dir,
          ','.join(signature_keys),
          ','.join(tags),
      ))

  graph_def = graph_pb2.GraphDef()
  graph_def.ParseFromString(graph_def_serialized)

  if output_directory is None:
    output_directory = tempfile.mkdtemp()
  v1_builder = builder.SavedModelBuilder(output_directory)

  with session.Session(graph=ops.Graph()) as sess:
    importer.import_graph_def(graph_def, name='')
    working_graph = ops.get_default_graph()

    signatures = _fix_tensor_names(signatures, working_graph)
    if signatures is None:
      raise ValueError(
          "The input SavedModel doesn't contain a valid signature")

    v1_builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING], signature_def_map=signatures)

  v1_builder.save()

  return saved_model_load(output_directory)
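# Illustrative usage sketch (not part of the original file): how a caller
# might drive the post-training path of the function above for a non-QAT
# SavedModel. The paths, the signature key 'serving_default', the input name
# 'input_tensor', the sample shape, and the number of samples are assumptions
# made up for this example; the function only requires that the generator
# yields {input_name: input_tensor} dicts (or (signature_key, dict) tuples)
# and that the output directory is empty.
def _example_static_range_ptq_usage():
  import numpy as np  # Assumed available; used only to fabricate sample data.

  def representative_dataset():
    # Yields a handful of calibration samples keyed by the signature's input
    # names. Eight samples is an arbitrary choice for illustration.
    for _ in range(8):
      yield {
          'input_tensor':
              np.random.uniform(size=(1, 28, 28, 1)).astype(np.float32)
      }

  return _static_range_quantize(
      saved_model_path='/tmp/float_saved_model',
      signature_keys=['serving_default'],
      tags={'serve'},
      output_directory='/tmp/quantized_saved_model',
      representative_dataset=representative_dataset)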
def _static_range_quantize(
    saved_model_path: str,
    signature_keys: Sequence[str],
    tags: Collection[str],
    output_directory: str,
    quantization_options: quant_opts_pb2.QuantizationOptions,
    representative_dataset: Optional[
        repr_dataset.RepresentativeDatasetOrMapping] = None
) -> ...:
  """Quantizes the given SavedModel via static range quantization.

  Args:
    saved_model_path: Path to the saved model. When representative_dataset is
      not provided, this should be a model trained with QAT.
    signature_keys: Sequence of keys identifying SignatureDef containing inputs
      and outputs.
    tags: Collection of tags identifying the MetaGraphDef within the SavedModel
      to analyze.
    output_directory: The path to save the output SavedModel. The directory
      will be overwritten if not empty.
    quantization_options: QuantizationOptions proto describing
      quantization-related config.
    representative_dataset: a generator that returns a dictionary in
      {input_key: input_value} format or a tuple with signature key and a
      dictionary in {input_key: input_value} format that feeds calibration data
      for quantizing the model. This should be provided when the model is not a
      QAT model.

  Returns:
    A SavedModel object with TF quantization applied.

  Raises:
    ValueError: when representative_dataset is not provided for a non-QAT
      model.
    RuntimeError: When a MetaGraphDef could not be found associated with `tags`
      in the SavedModel.
  """
  is_qat_saved_model = _is_qat_saved_model(saved_model_path)
  signatures = _get_signatures_from_saved_model(saved_model_path,
                                                signature_keys, tags)

  # Checks if the model is from QAT.
  if representative_dataset is None and not is_qat_saved_model:
    raise ValueError(
        'When `representative_dataset` is not provided, the model should be '
        'trained with quantization-aware training (QAT).')

  if is_qat_saved_model:
    # Handles QAT models.
    graph_def_serialized = (
        quantize_model_wrapper.quantize_qat_model(
            saved_model_path, ','.join(signature_keys), ','.join(tags),
            quantization_options.SerializeToString()))
  else:
    # Handles PTQ models by preparing them for calibration.
    graph_def_serialized = (
        quantize_model_wrapper.quantize_ptq_model_pre_calibration(
            saved_model_path, ','.join(signature_keys), ','.join(tags)))

  graph_def = graph_pb2.GraphDef()
  graph_def.ParseFromString(graph_def_serialized)

  float_model_dir = tempfile.mkdtemp()
  v1_builder = builder.SavedModelBuilder(float_model_dir)

  with session.Session(graph=ops.Graph()) as sess:
    # Assigns a unique id to each CustomAggregator op so calibration
    # statistics can be looked up per op later.
    for function_def in graph_def.library.function:
      for node_def in function_def.node_def:
        if node_def.op == 'CustomAggregator':
          node_def.attr['id'].s = uuid.uuid4().hex.encode('ascii')

    importer.import_graph_def(graph_def, name='')
    working_graph = ops.get_default_graph()
    graph_def = working_graph.as_graph_def()

    signatures = _fix_tensor_names(signatures, working_graph)
    if signatures is None:
      raise ValueError(
          "The input SavedModel doesn't contain a valid signature")

    v1_builder.add_meta_graph_and_variables(
        sess, tags, signature_def_map=signatures)

  v1_builder.save()

  # Uses the representative dataset to collect statistics for calibration.
  # Handles the graph mode execution separately in case TF2 is disabled or
  # eager execution is disabled. The min & max values are stored separately
  # in a global CalibratorSingleton instance.
  _run_graph_for_calibration(float_model_dir, signature_keys, tags,
                             representative_dataset)

  # Copies the collected min/max values into the CustomAggregator nodes.
  for function_def in graph_def.library.function:
    for node_def in function_def.node_def:
      if node_def.op == 'CustomAggregator':
        node_id = node_def.attr['id'].s
        try:
          min_val = quantize_model_wrapper.get_min_from_calibrator(node_id)
          max_val = quantize_model_wrapper.get_max_from_calibrator(node_id)
          quantize_model_wrapper.clear_data_from_calibrator(node_id)
          node_def.attr['min'].f = float(min_val)
          node_def.attr['max'].f = float(max_val)
        except ValueError:
          warnings.warn(
              f'CustomAggregator id "{node_id.decode("utf-8")}" from '
              f'FunctionDef "{function_def.signature.name}" does not have '
              'min or max values. This function may not be quantized.')

  calibrated_model_dir = tempfile.mkdtemp()
  v1_builder = builder.SavedModelBuilder(calibrated_model_dir)

  with session.Session(graph=ops.Graph()) as sess:
    importer.import_graph_def(graph_def, name='')
    working_graph = ops.get_default_graph()
    graph_def = working_graph.as_graph_def()

    v1_builder.add_meta_graph_and_variables(
        sess, tags, signature_def_map=signatures)

  v1_builder.save()

  signatures = _get_signatures_from_saved_model(calibrated_model_dir,
                                                signature_keys, tags)

  graph_def_serialized = (
      quantize_model_wrapper.quantize_ptq_model_post_calibration(
          calibrated_model_dir, ','.join(signature_keys), ','.join(tags),
          quantization_options.SerializeToString()))

  graph_def = graph_pb2.GraphDef()
  graph_def.ParseFromString(graph_def_serialized)

  _create_empty_output_dir(output_directory)
  v1_builder = builder.SavedModelBuilder(output_directory)

  with session.Session(graph=ops.Graph()) as sess:
    importer.import_graph_def(graph_def, name='')
    working_graph = ops.get_default_graph()

    signatures = _fix_tensor_names(signatures, working_graph)
    if signatures is None:
      raise ValueError(
          "The input SavedModel doesn't contain a valid signature")

    v1_builder.add_meta_graph_and_variables(
        sess, tags, signature_def_map=signatures)

  v1_builder.save()

  return saved_model_load(output_directory)
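# Illustrative usage sketch (not part of the original file): invoking the
# typed variant above with a default QuantizationOptions proto. The paths,
# the signature key 'serving_default', the input name 'x', and the sample
# shape are assumptions for this example, and QuantizationOptions is left at
# its proto defaults rather than configured field by field.
def _example_static_range_ptq_with_options():
  import numpy as np  # Assumed available; used only to fabricate sample data.

  def representative_dataset():
    # Yields plain {input_name: input_tensor} dicts, the simplest form
    # accepted by the calibration step above.
    for _ in range(4):
      yield {'x': np.ones(shape=(1, 4), dtype=np.float32)}

  return _static_range_quantize(
      saved_model_path='/tmp/float_saved_model',
      signature_keys=['serving_default'],
      tags={'serve'},
      output_directory='/tmp/quantized_saved_model',
      quantization_options=quant_opts_pb2.QuantizationOptions(),
      representative_dataset=representative_dataset)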
def _static_range_quantize(saved_model_path: str,
                           signature_keys=None,
                           tags=None,
                           output_directory=None,
                           representative_dataset=None):
  """Quantizes the given SavedModel via static range quantization.

  Args:
    saved_model_path: Path to the saved model. When representative_dataset is
      not provided, this should be a model trained with QAT.
    signature_keys: List of keys identifying SignatureDef containing inputs and
      outputs.
    tags: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze.
    output_directory: The path to save the output SavedModel (must be an empty
      directory).
    representative_dataset: a generator that returns a dictionary in
      {input_name: input_tensor} format or a tuple with signature key and a
      dictionary in {input_name: input_tensor} format that feeds calibration
      data for quantizing the model. This should be provided when the model is
      not a QAT model.

  Returns:
    A SavedModel object with TF quantization applied.

  Raises:
    ValueError: when representative_dataset is not provided for a non-QAT
      model.
  """
  is_qat_saved_model = _is_qat_saved_model(saved_model_path)
  signatures = _get_signatures_from_saved_model(saved_model_path,
                                                signature_keys, tags)

  # Checks if the model is from QAT.
  if representative_dataset is None and not is_qat_saved_model:
    raise ValueError(
        'When `representative_dataset` is not provided, the model should be '
        'trained with quantization-aware training (QAT).')

  if is_qat_saved_model:
    # Handles QAT models.
    graph_def_serialized = (
        quantize_model_wrapper.quantize_qat_model(saved_model_path,
                                                  ','.join(signature_keys),
                                                  ','.join(tags)))
  else:
    # Handles PTQ models by preparing them for calibration.
    graph_def_serialized = (
        quantize_model_wrapper.quantize_ptq_model_pre_calibration(
            saved_model_path, ','.join(signature_keys), ','.join(tags)))

  graph_def = graph_pb2.GraphDef()
  graph_def.ParseFromString(graph_def_serialized)

  float_model_dir = tempfile.mkdtemp()
  v1_builder = builder.SavedModelBuilder(float_model_dir)

  with session.Session(graph=ops.Graph()) as sess:
    # Assigns a unique id to each CustomAggregator op so calibration
    # statistics can be looked up per op later.
    for function_def in graph_def.library.function:
      for node_def in function_def.node_def:
        if node_def.op == 'CustomAggregator':
          node_def.attr['id'].s = uuid.uuid4().hex.encode('ascii')

    importer.import_graph_def(graph_def, name='')
    working_graph = ops.get_default_graph()
    graph_def = working_graph.as_graph_def()

    signatures = _fix_tensor_names(signatures, working_graph)
    if signatures is None:
      raise ValueError(
          "The input SavedModel doesn't contain a valid signature")

    v1_builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING], signature_def_map=signatures)

  v1_builder.save()

  float_model = saved_model_load(float_model_dir)

  # Uses the representative dataset to collect statistics for calibration.
  # Handles the graph mode execution separately in case TF2 is disabled or
  # eager execution is disabled.
  if context.executing_eagerly():
    _run_graph_for_calibration_eager_mode(float_model, signature_keys,
                                          representative_dataset)
  else:
    _run_graph_for_calibration_graph_mode(float_model, signature_keys,
                                          representative_dataset)

  # Copies the collected min/max values into the CustomAggregator nodes.
  for function_def in graph_def.library.function:
    for node_def in function_def.node_def:
      if node_def.op == 'CustomAggregator':
        node_id = node_def.attr['id'].s
        try:
          min_val = quantize_model_wrapper.get_min_from_calibrator(node_id)
          max_val = quantize_model_wrapper.get_max_from_calibrator(node_id)
          quantize_model_wrapper.clear_data_from_calibrator(node_id)
          node_def.attr['min'].f = float(min_val)
          node_def.attr['max'].f = float(max_val)
        except ValueError:
          warnings.warn(
              f'CustomAggregator id "{node_id.decode("utf-8")}" from '
              f'FunctionDef "{function_def.signature.name}" does not have '
              'min or max values. This function may not be quantized.')

  calibrated_model_dir = tempfile.mkdtemp()
  v1_builder = builder.SavedModelBuilder(calibrated_model_dir)

  with session.Session(graph=ops.Graph()) as sess:
    importer.import_graph_def(graph_def, name='')
    working_graph = ops.get_default_graph()
    graph_def = working_graph.as_graph_def()

    v1_builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING], signature_def_map=signatures)

  v1_builder.save()

  signatures = _get_signatures_from_saved_model(calibrated_model_dir,
                                                signature_keys, tags)

  graph_def_serialized = (
      quantize_model_wrapper.quantize_ptq_model_post_calibration(
          calibrated_model_dir,
          ','.join(signature_keys),
          ','.join(tags),
      ))

  graph_def = graph_pb2.GraphDef()
  graph_def.ParseFromString(graph_def_serialized)

  if output_directory is None:
    output_directory = tempfile.mkdtemp()
  v1_builder = builder.SavedModelBuilder(output_directory)

  with session.Session(graph=ops.Graph()) as sess:
    importer.import_graph_def(graph_def, name='')
    working_graph = ops.get_default_graph()

    signatures = _fix_tensor_names(signatures, working_graph)
    if signatures is None:
      raise ValueError(
          "The input SavedModel doesn't contain a valid signature")

    v1_builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING], signature_def_map=signatures)

  v1_builder.save()

  return saved_model_load(output_directory)
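# Minimal sketch (an assumption, not the library's implementation) of what the
# eager-mode calibration branch above conceptually does: each sample from the
# representative dataset is fed through the matching ConcreteFunction of the
# loaded float model so that the inserted CustomAggregator ops record min/max
# statistics as a side effect. The actual _run_graph_for_calibration_eager_mode
# helper may differ in input validation and error handling.
def _sketch_eager_mode_calibration(float_model, signature_key,
                                   representative_dataset):
  func = float_model.signatures[signature_key]
  for sample in representative_dataset():
    # Keyword arguments must match the input names of the chosen signature.
    func(**sample)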