Example #1
def _static_range_quantize(saved_model_path: str,
                           signature_keys=None,
                           tags=None,
                           output_directory=None,
                           representative_dataset=None):
    """Quantizes the given SavedModel via static range quantization.

  Args:
    saved_model_path: Path to the saved model. When representative_dataset is
      not provided, this should be a model trained with QAT.
    signature_keys: List of keys identifying SignatureDef containing inputs and
      outputs.
    tags: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze.
    output_directory: The path to save the output SavedModel (must be an empty
      directory).
    representative_dataset: a generator that returns a dictionary in
      {input_name: input_tensor} format or a tuple with signature key and a
      dictionary in {input_name: input_tensor} format that feeds calibration
      data for quantizing the model. This should be provided when the model is
      not a QAT model.

  Returns:
    A SavedModel object with TF quantization applied.

  Raises:
    ValueError: when representative_dataset is not provided for non-QAT model.
  """
    is_qat_saved_model = _is_qat_saved_model(saved_model_path)
    signatures = _get_signatures_from_saved_model(saved_model_path,
                                                  signature_keys, tags)

    # Checks if the model is from QAT
    if representative_dataset is None and not is_qat_saved_model:
        raise ValueError(
            'When `representative_dataset` is not provided, the model should be '
            'trained with quantization-aware training (QAT).')

    if is_qat_saved_model:
        # Handles QAT models, which do not need the calibration step below.
        graph_def_serialized = (quantize_model_wrapper.quantize_qat_model(
            saved_model_path, ','.join(signature_keys), ','.join(tags)))
    else:
        # Handles PTQ models, which require calibration with the representative
        # dataset before quantization.
        graph_def_serialized = (
            quantize_model_wrapper.quantize_ptq_model_pre_calibration(
                saved_model_path, ','.join(signature_keys), ','.join(tags)))

        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(graph_def_serialized)

        float_model_dir = tempfile.mkdtemp()
        v1_builder = builder.SavedModelBuilder(float_model_dir)

        with session.Session(graph=ops.Graph()) as sess:
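            # Assigns a fresh unique id to every CustomAggregator op so its
            # calibration statistics can be looked up after the dataset is run.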
            for function_def in graph_def.library.function:
                for node_def in function_def.node_def:
                    if node_def.op == 'CustomAggregator':
                        node_def.attr['id'].s = uuid.uuid4().hex.encode(
                            'ascii')

            importer.import_graph_def(graph_def, name='')
            working_graph = ops.get_default_graph()
            graph_def = working_graph.as_graph_def()

            signatures = _fix_tensor_names(signatures, working_graph)
            if signatures is None:
                raise ValueError(
                    "The input SavedModel doesn't contain a valid signature")

            v1_builder.add_meta_graph_and_variables(
                sess, [tag_constants.SERVING], signature_def_map=signatures)

        v1_builder.save()

        float_model = saved_model_load(float_model_dir)

        # Feeds each representative dataset sample through the float model so
        # the CustomAggregator ops record calibration statistics.
        for sample in representative_dataset():
            # TODO(b/214311251): Add a test case with multiple signatures.
            if isinstance(sample, tuple):
                if not isinstance(sample[1], dict):
                    raise ValueError(
                        'You need to provide a dictionary with input '
                        'names and values in the second argument in the '
                        'tuple')
                signature_key = sample[0]
                input_data_map = sample[1]
            elif isinstance(sample, dict):
                if len(signature_keys) > 1:
                    raise ValueError(
                        'When the model has multiple signatures, you need '
                        'to provide a tuple with signature key and a '
                        'dictionary with input names and values')
                signature_key = signature_keys[0]
                input_data_map = sample
            else:
                raise ValueError(
                    'You need to provide either a dictionary with input '
                    'names and values or a tuple with signature key and a '
                    'dictionary with input names and values')
            func = float_model.signatures[signature_key]
            func(**input_data_map)

        # Reads the collected min/max values back from the calibrator and
        # embeds them as 'min'/'max' attributes on the CustomAggregator nodes.
        for function_def in graph_def.library.function:
            for node_def in function_def.node_def:
                if node_def.op == 'CustomAggregator':
                    node_id = node_def.attr['id'].s
                    try:
                        min_val = quantize_model_wrapper.get_min_from_calibrator(
                            node_id)
                        max_val = quantize_model_wrapper.get_max_from_calibrator(
                            node_id)
                        quantize_model_wrapper.clear_data_from_calibrator(
                            node_id)
                        node_def.attr['min'].f = float(min_val)
                        node_def.attr['max'].f = float(max_val)
                    except ValueError:
                        warnings.warn('%s does not have min/max values.' %
                                      node_id.decode('utf-8'))

        calibrated_model_dir = tempfile.mkdtemp()
        v1_builder = builder.SavedModelBuilder(calibrated_model_dir)

        with session.Session(graph=ops.Graph()) as sess:
            importer.import_graph_def(graph_def, name='')
            working_graph = ops.get_default_graph()
            graph_def = working_graph.as_graph_def()

            v1_builder.add_meta_graph_and_variables(
                sess, [tag_constants.SERVING], signature_def_map=signatures)

        v1_builder.save()
        signatures = _get_signatures_from_saved_model(calibrated_model_dir,
                                                      signature_keys, tags)

        graph_def_serialized = (
            quantize_model_wrapper.quantize_ptq_model_post_calibration(
                calibrated_model_dir,
                ','.join(signature_keys),
                ','.join(tags),
            ))

    graph_def = graph_pb2.GraphDef()
    graph_def.ParseFromString(graph_def_serialized)

    if output_directory is None:
        output_directory = tempfile.mkdtemp()
    v1_builder = builder.SavedModelBuilder(output_directory)

    with session.Session(graph=ops.Graph()) as sess:
        importer.import_graph_def(graph_def, name='')
        working_graph = ops.get_default_graph()

        signatures = _fix_tensor_names(signatures, working_graph)
        if signatures is None:
            raise ValueError(
                "The input SavedModel doesn't contain a valid signature")

        v1_builder.add_meta_graph_and_variables(sess, [tag_constants.SERVING],
                                                signature_def_map=signatures)

    v1_builder.save()

    return saved_model_load(output_directory)
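
A minimal usage sketch for the API above, assuming a single-signature model;
the paths, the signature key 'serving_default', the tag 'serve', the input
name 'input_tensor', and the sample count are illustrative placeholders rather
than values taken from the code.

import numpy as np

def representative_dataset():
    # Yields {input_name: input_tensor} dictionaries, as described in the
    # docstring; 'input_tensor' is a hypothetical input name.
    for _ in range(128):
        yield {'input_tensor': np.random.rand(1, 224, 224, 3).astype(np.float32)}

quantized_model = _static_range_quantize(
    saved_model_path='/tmp/float_model',          # hypothetical path
    signature_keys=['serving_default'],
    tags={'serve'},
    output_directory='/tmp/quantized_model',      # hypothetical path
    representative_dataset=representative_dataset)
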
Example #2
def _static_range_quantize(
    saved_model_path: str,
    signature_keys: Sequence[str],
    tags: Collection[str],
    output_directory: str,
    quantization_options: quant_opts_pb2.QuantizationOptions,
    representative_dataset: Optional[
        repr_dataset.RepresentativeDatasetOrMapping] = None
) -> ...:
    """Quantizes the given SavedModel via static range quantization.

  Args:
    saved_model_path: Path to the saved model. When representative_dataset is
      not provided, this should be a model trained with QAT.
    signature_keys: Sequence of keys identifying SignatureDef containing inputs
      and outputs.
    tags: Collection of tags identifying the MetaGraphDef within the SavedModel
      to analyze.
    output_directory: The path to save the output SavedModel. The directory will
      be overwritten if not empty.
    quantization_options: QuantizationOptions proto describing quantization
      related config.
    representative_dataset: a generator that returns a dictionary in {input_key:
      input_value} format or a tuple with signature key and a dictionary in
      {input_key: input_value} format that feeds calibration data for quantizing
      the model. This should be provided when the model is not a QAT model.

  Returns:
    A SavedModel object with TF quantization applied.

  Raises:
    ValueError: when representative_dataset is not provided for non-QAT model.
    RuntimeError: When a MetaGraphDef could not be found associated with `tags`
      in the SavedModel.
  """
    is_qat_saved_model = _is_qat_saved_model(saved_model_path)
    signatures = _get_signatures_from_saved_model(saved_model_path,
                                                  signature_keys, tags)

    # Checks if the model is from QAT
    if representative_dataset is None and not is_qat_saved_model:
        raise ValueError(
            'When `representative_dataset` is not provided, the model should be '
            'trained with quantization-aware training (QAT).')

    if is_qat_saved_model:
        # Handles QAT models, which do not need the calibration step below.
        graph_def_serialized = (quantize_model_wrapper.quantize_qat_model(
            saved_model_path, ','.join(signature_keys), ','.join(tags),
            quantization_options.SerializeToString()))
    else:
        # Handles PTQ models, which require calibration with the representative
        # dataset before quantization.
        graph_def_serialized = (
            quantize_model_wrapper.quantize_ptq_model_pre_calibration(
                saved_model_path, ','.join(signature_keys), ','.join(tags)))

        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(graph_def_serialized)

        float_model_dir = tempfile.mkdtemp()
        v1_builder = builder.SavedModelBuilder(float_model_dir)

        with session.Session(graph=ops.Graph()) as sess:
            for function_def in graph_def.library.function:
                for node_def in function_def.node_def:
                    if node_def.op == 'CustomAggregator':
                        node_def.attr['id'].s = uuid.uuid4().hex.encode(
                            'ascii')

            importer.import_graph_def(graph_def, name='')
            working_graph = ops.get_default_graph()
            graph_def = working_graph.as_graph_def()

            signatures = _fix_tensor_names(signatures, working_graph)
            if signatures is None:
                raise ValueError(
                    "The input SavedModel doesn't contain a valid signature")

            v1_builder.add_meta_graph_and_variables(
                sess, tags, signature_def_map=signatures)

        v1_builder.save()

        # Uses the representative dataset to collect statistics for calibration.
        # Handles the graph mode execution separately in case TF2 is disabled or
        # eager execution is disabled. The min & max values are stored separately
        # in a global CalibratorSingleton instance.
        _run_graph_for_calibration(float_model_dir, signature_keys, tags,
                                   representative_dataset)

        for function_def in graph_def.library.function:
            for node_def in function_def.node_def:
                if node_def.op == 'CustomAggregator':
                    node_id = node_def.attr['id'].s
                    try:
                        min_val = quantize_model_wrapper.get_min_from_calibrator(
                            node_id)
                        max_val = quantize_model_wrapper.get_max_from_calibrator(
                            node_id)
                        quantize_model_wrapper.clear_data_from_calibrator(
                            node_id)
                        node_def.attr['min'].f = float(min_val)
                        node_def.attr['max'].f = float(max_val)
                    except ValueError:
                        warnings.warn(
                            f'CustomAggregator id "{node_id.decode("utf-8")}" from '
                            f'FunctionDef "{function_def.signature.name}" does not have '
                            'min or max values. This function may not be quantized.'
                        )

        calibrated_model_dir = tempfile.mkdtemp()
        v1_builder = builder.SavedModelBuilder(calibrated_model_dir)

        with session.Session(graph=ops.Graph()) as sess:
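            # Imports the graph annotated with min/max values and re-saves it
            # as the calibrated SavedModel consumed by the post-calibration step.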
            importer.import_graph_def(graph_def, name='')
            working_graph = ops.get_default_graph()
            graph_def = working_graph.as_graph_def()

            v1_builder.add_meta_graph_and_variables(
                sess, tags, signature_def_map=signatures)

        v1_builder.save()
        signatures = _get_signatures_from_saved_model(calibrated_model_dir,
                                                      signature_keys, tags)

        graph_def_serialized = (
            quantize_model_wrapper.quantize_ptq_model_post_calibration(
                calibrated_model_dir, ','.join(signature_keys), ','.join(tags),
                quantization_options.SerializeToString()))

    graph_def = graph_pb2.GraphDef()
    graph_def.ParseFromString(graph_def_serialized)

    _create_empty_output_dir(output_directory)
    v1_builder = builder.SavedModelBuilder(output_directory)

    with session.Session(graph=ops.Graph()) as sess:
        importer.import_graph_def(graph_def, name='')
        working_graph = ops.get_default_graph()

        signatures = _fix_tensor_names(signatures, working_graph)
        if signatures is None:
            raise ValueError(
                "The input SavedModel doesn't contain a valid signature")

        v1_builder.add_meta_graph_and_variables(sess,
                                                tags,
                                                signature_def_map=signatures)

    v1_builder.save()

    return saved_model_load(output_directory)
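
A minimal usage sketch for this version, which additionally takes a
QuantizationOptions proto; here a default-constructed options message is
passed, and the paths, signature key, tag, and input key are illustrative
placeholders rather than values taken from the code.

import numpy as np

def representative_dataset():
    # Same {input_key: input_value} form as in the sketch after Example #1;
    # 'input_tensor' is a hypothetical input key.
    for _ in range(128):
        yield {'input_tensor': np.random.rand(1, 224, 224, 3).astype(np.float32)}

quantized_model = _static_range_quantize(
    saved_model_path='/tmp/float_model',          # hypothetical path
    signature_keys=['serving_default'],
    tags={'serve'},
    output_directory='/tmp/quantized_model',      # hypothetical path
    quantization_options=quant_opts_pb2.QuantizationOptions(),
    representative_dataset=representative_dataset)
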
Example #3
def _static_range_quantize(saved_model_path: str,
                           signature_keys=None,
                           tags=None,
                           output_directory=None,
                           representative_dataset=None):
  """Quantizes the given SavedModel via static range quantization.

  Args:
    saved_model_path: Path to the saved model. When representative_dataset is
      not provided, this should be a model trained with QAT.
    signature_keys: List of keys identifying SignatureDef containing inputs and
      outputs.
    tags: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze.
    output_directory: The path to save the output SavedModel (must be an empty
      directory).
    representative_dataset: a generator that returns a dictionary in
      {input_name: input_tensor} format or a tuple with signature key and a
      dictionary in {input_name: input_tensor} format that feeds calibration
      data for quantizing the model. This should be provided when the model is
      not a QAT model.

  Returns:
    A SavedModel object with TF quantization applied.

  Raises:
    ValueError: when representative_dataset is not provided for non-QAT model.
  """
  is_qat_saved_model = _is_qat_saved_model(saved_model_path)
  signatures = _get_signatures_from_saved_model(saved_model_path,
                                                signature_keys, tags)

  # Checks if the model is from QAT
  if representative_dataset is None and not is_qat_saved_model:
    raise ValueError(
        'When `representative_dataset` is not provided, the model should be '
        'trained with quantization-aware training (QAT).')

  if is_qat_saved_model:
    # Handles QAT models, which do not need the calibration step below.
    graph_def_serialized = (
        quantize_model_wrapper.quantize_qat_model(saved_model_path,
                                                  ','.join(signature_keys),
                                                  ','.join(tags)))
  else:
    # Handles PTQ models, which require calibration with the representative
    # dataset before quantization.
    graph_def_serialized = (
        quantize_model_wrapper.quantize_ptq_model_pre_calibration(
            saved_model_path, ','.join(signature_keys), ','.join(tags)))

    graph_def = graph_pb2.GraphDef()
    graph_def.ParseFromString(graph_def_serialized)

    float_model_dir = tempfile.mkdtemp()
    v1_builder = builder.SavedModelBuilder(float_model_dir)

    with session.Session(graph=ops.Graph()) as sess:
      for function_def in graph_def.library.function:
        for node_def in function_def.node_def:
          if node_def.op == 'CustomAggregator':
            node_def.attr['id'].s = uuid.uuid4().hex.encode('ascii')

      importer.import_graph_def(graph_def, name='')
      working_graph = ops.get_default_graph()
      graph_def = working_graph.as_graph_def()

      signatures = _fix_tensor_names(signatures, working_graph)
      if signatures is None:
        raise ValueError(
            "The input SavedModel doesn't contain a valid signature")

      v1_builder.add_meta_graph_and_variables(
          sess, [tag_constants.SERVING], signature_def_map=signatures)

    v1_builder.save()

    float_model = saved_model_load(float_model_dir)

    # Uses the representative dataset to collect statistics for calibration.
    # Handles the graph mode execution separately in case TF2 is disabled or
    # eager execution is disabled.
    if context.executing_eagerly():
      _run_graph_for_calibration_eager_mode(float_model, signature_keys,
                                            representative_dataset)
    else:
      _run_graph_for_calibration_graph_mode(float_model, signature_keys,
                                            representative_dataset)

    for function_def in graph_def.library.function:
      for node_def in function_def.node_def:
        if node_def.op == 'CustomAggregator':
          node_id = node_def.attr['id'].s
          try:
            min_val = quantize_model_wrapper.get_min_from_calibrator(node_id)
            max_val = quantize_model_wrapper.get_max_from_calibrator(node_id)
            quantize_model_wrapper.clear_data_from_calibrator(node_id)
            node_def.attr['min'].f = float(min_val)
            node_def.attr['max'].f = float(max_val)
          except ValueError:
            warnings.warn(
                f'CustomAggregator id "{node_id.decode("utf-8")}" from '
                f'FunctionDef "{function_def.signature.name}" does not have '
                'min or max values. This function may not be quantized.')

    calibrated_model_dir = tempfile.mkdtemp()
    v1_builder = builder.SavedModelBuilder(calibrated_model_dir)

    with session.Session(graph=ops.Graph()) as sess:
      importer.import_graph_def(graph_def, name='')
      working_graph = ops.get_default_graph()
      graph_def = working_graph.as_graph_def()

      v1_builder.add_meta_graph_and_variables(
          sess, [tag_constants.SERVING], signature_def_map=signatures)

    v1_builder.save()
    signatures = _get_signatures_from_saved_model(calibrated_model_dir,
                                                  signature_keys, tags)

    graph_def_serialized = (
        quantize_model_wrapper.quantize_ptq_model_post_calibration(
            calibrated_model_dir,
            ','.join(signature_keys),
            ','.join(tags),
        ))

  graph_def = graph_pb2.GraphDef()
  graph_def.ParseFromString(graph_def_serialized)

  if output_directory is None:
    output_directory = tempfile.mkdtemp()
  v1_builder = builder.SavedModelBuilder(output_directory)

  with session.Session(graph=ops.Graph()) as sess:
    importer.import_graph_def(graph_def, name='')
    working_graph = ops.get_default_graph()

    signatures = _fix_tensor_names(signatures, working_graph)
    if signatures is None:
      raise ValueError("The input SavedModel doesn't contain a valid signature")

    v1_builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING], signature_def_map=signatures)

  v1_builder.save()

  return saved_model_load(output_directory)
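
A minimal sketch of a representative dataset for a model with more than one
signature, using the (signature_key, {input_name: input_tensor}) tuple form
the docstring describes; the signature keys, input names, shapes, and paths
are illustrative placeholders rather than values taken from the code.

import numpy as np

def representative_dataset():
  for _ in range(128):
    # Hypothetical signatures 'encode' and 'decode' with hypothetical inputs.
    yield ('encode', {'token_ids': np.random.randint(0, 1000, size=(1, 16))})
    yield ('decode', {'latent': np.random.rand(1, 64).astype(np.float32)})

quantized_model = _static_range_quantize(
    saved_model_path='/tmp/float_model',        # hypothetical path
    signature_keys=['encode', 'decode'],
    tags={'serve'},
    output_directory='/tmp/quantized_model',    # hypothetical path
    representative_dataset=representative_dataset)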