예제 #1
0
def _run_graph_for_calibration_eager_mode(
    model_dir: str,
    tags: Collection[str],
    representative_dataset_map: repr_dataset.RepresentativeDatasetMapping,
) -> None:
    """Feeds each representative dataset through its signature in eager mode.

    Assumes eager execution (on by default in TF2). The SavedModel is loaded
    once. Then, for every (signature key, dataset) pair in
    `representative_dataset_map`, the dataset is run through the signature
    function with that key. Running the data is what lets the
    CustomAggregatorOp nodes collect the statistics used for quantization.

    Args:
      model_dir: Path to SavedModel directory.
      tags: Collection of tags identifying the MetaGraphDef within the
        SavedModel.
      representative_dataset_map: Mapping from signature key to the
        representative dataset that should be run through that signature.

    Raises:
      ValueError: If looking up a signature or running its dataset fails with
        any exception. The underlying exception is chained as the cause.
    """
    loaded_root: autotrackable.AutoTrackable = saved_model_load(model_dir, tags)

    for key, dataset in representative_dataset_map.items():
        try:
            # The lookup sits inside the `try` so that a failed lookup is also
            # reported as a ValueError naming the signature key.
            signature_fn = loaded_root.signatures[key]
            _run_function_for_calibration_eager_mode(
                func=signature_fn, representative_dataset=dataset)
        except Exception as err:
            raise ValueError(
                'Failed to run representative dataset through the '
                f'function with the signature key: {key}.') from err
예제 #2
0
def _run_graph_for_calibration_eager_mode(
        model_dir: str, signature_keys: List[str], tags: Set[str],
        representative_dataset: repr_dataset.RepresentativeDataset) -> None:
    """Runs every representative sample through its signature in eager mode.

    Assumes eager execution (on by default in TF2). Each sample is resolved to
    a signature key and an input dict via `_get_signature_key_and_input`. The
    matching signature function of the loaded SavedModel is then called with
    those inputs, letting the CustomAggregatorOp nodes collect the statistics
    used for quantization.

    Args:
      model_dir: Path to SavedModel directory.
      signature_keys: Signature keys of the functions the samples may be run
        with.
      tags: Set of tags identifying the MetaGraphDef within the SavedModel.
      representative_dataset: Representative dataset used for calibration.

    Raises:
      ValueError: If calling a signature function fails. The underlying
        exception is chained as the cause. Errors from
        `_get_signature_key_and_input` (presumably a ValueError for malformed
        samples) and from the signature lookup propagate unchanged.
    """
    loaded_root: autotrackable.AutoTrackable = saved_model_load(model_dir, tags)

    for sample in representative_dataset:
        key, inputs = _get_signature_key_and_input(sample, signature_keys)
        # Looked up outside the `try`: only the call itself is wrapped.
        signature_fn = loaded_root.signatures[key]
        try:
            signature_fn(**inputs)
        except Exception as err:
            raise ValueError(
                f'Failed to run the function with signature key: {key}'
            ) from err
예제 #3
0
  def _GetFunc(self, use_trt, model_dir, use_dynamic_shape):
    """Builds the mnist inference function, optionally converted by TF-TRT.

    The checkpoints in `model_dir` are first exported as a SavedModel into a
    temporary directory. That SavedModel is then either converted with TF-TRT
    (FP16 precision) or loaded as-is, and the resulting function is returned.

    Args:
      use_trt: If True, convert the graph with TF-TRT. Otherwise load the
        plain SavedModel and use its default serving signature.
      model_dir: The model directory to load the checkpoints from.
      use_dynamic_shape: Whether to run the TF-TRT conversion in dynamic shape
        mode. Only consulted when `use_trt` is True.

    Returns:
      The mnist model function.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
      export_dir = os.path.join(scratch_dir, 'mnist')
      self._SaveModel(model_dir, export_dir)

      if not use_trt:
        loaded = saved_model_load(export_dir, tags=[tag_constants.SERVING])
        return loaded.signatures[
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

      params = trt_convert.TrtConversionParams(
          precision_mode='FP16',
          minimum_segment_size=2,
          max_workspace_size_bytes=1 << 28,
          maximum_cached_engines=1)
      trt_converter = trt_convert.TrtGraphConverterV2(
          input_saved_model_dir=export_dir,
          conversion_params=params,
          use_dynamic_shape=use_dynamic_shape,
          dynamic_shape_profile_strategy='ImplicitBatchModeCompatible')
      trt_converter.convert()
      return trt_converter._converted_func
예제 #4
0
def _dynamic_range_quantize(
    saved_model_path: str,
    signature_keys: Sequence[str],
    tags: Collection[str],
    output_directory: str,
    quantization_options: quant_opts_pb2.QuantizationOptions,
) -> ...:
    """Quantizes the given SavedModel via post-training dynamic range quantization.

    The quantized GraphDef produced by
    `quantize_model_wrapper.quantize_ptq_dynamic_range` is imported into a
    fresh graph. It is written as a SavedModel (tagged with
    `tag_constants.SERVING`) to `output_directory`, which is then loaded and
    returned.

    Args:
      saved_model_path: Path to the saved model.
      signature_keys: Sequence of keys identifying SignatureDef containing
        inputs and outputs.
      tags: Collection of tags identifying the MetaGraphDef within the
        SavedModel to analyze.
      output_directory: The path to save the output SavedModel. The directory
        will be overwritten if not empty.
      quantization_options: QuantizationOptions proto describing quantization
        related config.

    Returns:
      A SavedModel object with TF quantization applied.

    Raises:
      ValueError: when the model is a QAT model, or when no valid signature
        remains after fixing tensor names against the quantized graph.
    """
    qat_model = _is_qat_saved_model(saved_model_path)
    signature_defs = _get_signatures_from_saved_model(saved_model_path,
                                                      signature_keys, tags)

    # Dynamic range quantization is only offered for non-QAT models.
    if qat_model:
        raise ValueError(
            'The models trained with quantization-aware training (QAT) is not '
            'supported.')

    serialized = quantize_model_wrapper.quantize_ptq_dynamic_range(
        saved_model_path, ','.join(signature_keys), ','.join(tags),
        quantization_options.SerializeToString())
    quantized_graph_def = graph_pb2.GraphDef()
    quantized_graph_def.ParseFromString(serialized)

    _create_empty_output_dir(output_directory)
    saved_model_builder = builder.SavedModelBuilder(output_directory)

    with session.Session(graph=ops.Graph()) as sess:
        importer.import_graph_def(quantized_graph_def, name='')
        # Re-point the signature tensor names at the imported graph.
        signature_defs = _fix_tensor_names(signature_defs,
                                           ops.get_default_graph())
        if signature_defs is None:
            raise ValueError(
                "The input SavedModel doesn't contain a valid signature")

        saved_model_builder.add_meta_graph_and_variables(
            sess, [tag_constants.SERVING], signature_def_map=signature_defs)

    saved_model_builder.save()

    return saved_model_load(output_directory)
예제 #5
0
  def _GetFunc(self, use_trt, model_dir, use_dynamic_shape):
    """Builds the mnist inference function, optionally converted by TF-TRT.

    The checkpoints in `model_dir` are first exported as a SavedModel into a
    temporary directory. With TF-TRT enabled, the SavedModel is converted
    (FP16 precision, default workspace size) and a detailed conversion summary
    is printed. Otherwise the plain SavedModel is loaded and its default
    serving signature is used.

    Args:
      use_trt: Whether to convert the graph with TF-TRT.
      model_dir: The model directory to load the checkpoints from.
      use_dynamic_shape: Whether to run the TF-TRT conversion in dynamic shape
        mode. Only consulted when `use_trt` is True.

    Returns:
      The mnist model function.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
      export_dir = os.path.join(scratch_dir, 'mnist')
      self._SaveModel(model_dir, export_dir)

      if not use_trt:
        loaded = saved_model_load(export_dir, tags=[tag_constants.SERVING])
        return loaded.signatures[
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

      params = trt_convert.TrtConversionParams(
          precision_mode='FP16',
          minimum_segment_size=2,
          max_workspace_size_bytes=(
              trt_convert.DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES),
          maximum_cached_engines=1)
      trt_converter = trt_convert.TrtGraphConverterV2(
          input_saved_model_dir=export_dir,
          use_dynamic_shape=use_dynamic_shape,
          dynamic_shape_profile_strategy='ImplicitBatchModeCompatible',
          **params._asdict())
      trt_converter.convert()

      # The summary is at least 160 columns wide. It widens to the terminal
      # when a terminal size is available; without one (OSError) it uses 160.
      try:
        terminal_columns = os.get_terminal_size().columns
      except OSError:
        summary_width = 160
      else:
        summary_width = max(160, terminal_columns)
      trt_converter.summary(line_length=summary_width, detailed=True)
      return trt_converter._converted_func
예제 #6
0
def _static_range_quantize(
    saved_model_path: str,
    signature_keys: Sequence[str],
    tags: Collection[str],
    output_directory: str,
    quantization_options: quant_opts_pb2.QuantizationOptions,
    representative_dataset: Optional[
        repr_dataset.RepresentativeDatasetOrMapping] = None
) -> ...:
    """Quantizes the given SavedModel via static range quantization.

    QAT models are quantized directly. Any other model goes through
    post-training quantization with calibration:

      1. A pre-calibration graph is produced, and every `CustomAggregator`
         node in its function library is given a unique id. The graph is
         saved to a temporary SavedModel.
      2. `representative_dataset` is run through that SavedModel so the
         calibrator collects min/max statistics per aggregator id.
      3. The statistics are written into the aggregators' `min`/`max`
         attributes. The calibrated graph is saved to a second temporary
         SavedModel, which is then quantized.

    Both temporary SavedModel directories are deleted before this function
    returns, whether quantization succeeds or fails.

    Args:
      saved_model_path: Path to the saved model. When representative_dataset
        is not provided, this should be a model trained with QAT.
      signature_keys: Sequence of keys identifying SignatureDef containing
        inputs and outputs.
      tags: Collection of tags identifying the MetaGraphDef within the
        SavedModel to analyze.
      output_directory: The path to save the output SavedModel. The directory
        will be overwritten if not empty.
      quantization_options: QuantizationOptions proto describing quantization
        related config.
      representative_dataset: a generator that returns a dictionary in
        {input_key: input_value} format or a tuple with signature key and a
        dictionary in {input_key: input_value} format that feeds calibration
        data for quantizing model. This should be provided when the model is
        not a QAT model.

    Returns:
      A SavedModel object with TF quantization applied.

    Raises:
      ValueError: when representative_dataset is not provided for non-QAT
        model, or when no valid signature remains after fixing tensor names
        against a rebuilt graph.
      RuntimeError: When a MetaGraphDef could not be found associated with
        `tags` in the SavedModel.
    """
    is_qat_saved_model = _is_qat_saved_model(saved_model_path)
    signatures = _get_signatures_from_saved_model(saved_model_path,
                                                  signature_keys, tags)

    # Without a representative dataset there is nothing to calibrate with, so
    # only QAT models are accepted in that case.
    if representative_dataset is None and not is_qat_saved_model:
        raise ValueError(
            'When `representative_dataset` is not provided, the model should be '
            'trained with quantization-aware training (QAT).')

    if is_qat_saved_model:
        # QAT models are quantized directly.
        graph_def_serialized = (quantize_model_wrapper.quantize_qat_model(
            saved_model_path, ','.join(signature_keys), ','.join(tags),
            quantization_options.SerializeToString()))
    else:
        # PTQ models are quantized after calibration. The intermediate
        # SavedModels are written to temporary directories that are always
        # removed afterwards (tempfile.mkdtemp never deletes them itself).
        import shutil  # pylint: disable=g-import-not-at-top

        temp_dirs = []
        try:
            graph_def_serialized = (
                quantize_model_wrapper.quantize_ptq_model_pre_calibration(
                    saved_model_path, ','.join(signature_keys), ','.join(tags)))

            graph_def = graph_pb2.GraphDef()
            graph_def.ParseFromString(graph_def_serialized)

            float_model_dir = tempfile.mkdtemp()
            temp_dirs.append(float_model_dir)
            v1_builder = builder.SavedModelBuilder(float_model_dir)

            with session.Session(graph=ops.Graph()) as sess:
                # Give every CustomAggregator a unique id. Its statistics are
                # looked up by this id after calibration.
                for function_def in graph_def.library.function:
                    for node_def in function_def.node_def:
                        if node_def.op == 'CustomAggregator':
                            node_def.attr['id'].s = uuid.uuid4().hex.encode(
                                'ascii')

                importer.import_graph_def(graph_def, name='')
                working_graph = ops.get_default_graph()
                graph_def = working_graph.as_graph_def()

                signatures = _fix_tensor_names(signatures, working_graph)
                if signatures is None:
                    raise ValueError(
                        "The input SavedModel doesn't contain a valid signature")

                v1_builder.add_meta_graph_and_variables(
                    sess, tags, signature_def_map=signatures)

            v1_builder.save()

            # Uses the representative dataset to collect statistics for
            # calibration. Handles the graph mode execution separately in case
            # TF2 is disabled or eager execution is disabled. The min & max
            # values are stored separately in a global CalibratorSingleton
            # instance.
            _run_graph_for_calibration(float_model_dir, signature_keys, tags,
                                       representative_dataset)

            # Copy the recorded statistics into each aggregator's min/max
            # attributes and clear them from the calibrator.
            for function_def in graph_def.library.function:
                for node_def in function_def.node_def:
                    if node_def.op != 'CustomAggregator':
                        continue
                    node_id = node_def.attr['id'].s
                    try:
                        min_val = quantize_model_wrapper.get_min_from_calibrator(
                            node_id)
                        max_val = quantize_model_wrapper.get_max_from_calibrator(
                            node_id)
                        quantize_model_wrapper.clear_data_from_calibrator(
                            node_id)
                        node_def.attr['min'].f = float(min_val)
                        node_def.attr['max'].f = float(max_val)
                    except ValueError:
                        warnings.warn(
                            f'CustomAggregator id "{node_id.decode("utf-8")}" from '
                            f'FunctionDef "{function_def.signature.name}" does not have '
                            'min or max values. This function may not be quantized.'
                        )

            calibrated_model_dir = tempfile.mkdtemp()
            temp_dirs.append(calibrated_model_dir)
            v1_builder = builder.SavedModelBuilder(calibrated_model_dir)

            with session.Session(graph=ops.Graph()) as sess:
                importer.import_graph_def(graph_def, name='')
                v1_builder.add_meta_graph_and_variables(
                    sess, tags, signature_def_map=signatures)

            v1_builder.save()
            signatures = _get_signatures_from_saved_model(calibrated_model_dir,
                                                          signature_keys, tags)

            graph_def_serialized = (
                quantize_model_wrapper.quantize_ptq_model_post_calibration(
                    calibrated_model_dir, ','.join(signature_keys),
                    ','.join(tags), quantization_options.SerializeToString()))
        finally:
            # Best-effort cleanup: a failed removal must not mask the result
            # or the original error.
            for temp_dir in temp_dirs:
                shutil.rmtree(temp_dir, ignore_errors=True)

    graph_def = graph_pb2.GraphDef()
    graph_def.ParseFromString(graph_def_serialized)

    _create_empty_output_dir(output_directory)
    v1_builder = builder.SavedModelBuilder(output_directory)

    with session.Session(graph=ops.Graph()) as sess:
        importer.import_graph_def(graph_def, name='')
        working_graph = ops.get_default_graph()

        signatures = _fix_tensor_names(signatures, working_graph)
        if signatures is None:
            raise ValueError(
                "The input SavedModel doesn't contain a valid signature")

        v1_builder.add_meta_graph_and_variables(sess,
                                                tags,
                                                signature_def_map=signatures)

    v1_builder.save()

    return saved_model_load(output_directory)
예제 #7
0
def _static_range_quantize(saved_model_path: str,
                           signature_keys=None,
                           tags=None,
                           output_directory=None,
                           representative_dataset=None):
    """Quantizes the given SavedModel via static range quantization.

    QAT models are quantized directly. Other models are quantized after
    calibration, in these steps:

      1. The pre-calibration graph, with uniquely identified
         `CustomAggregator` nodes, is saved to a temporary SavedModel.
      2. Every sample from `representative_dataset` is run through it.
      3. The recorded min/max values are written back into the aggregator
         nodes.
      4. The calibrated graph is saved to a second temporary SavedModel,
         which is then quantized.

    Both intermediate SavedModels are saved under `tags`, so the later lookups
    that use `tags` find their MetaGraphDef. Both temporary directories are
    deleted before returning.

    Args:
      saved_model_path: Path to the saved model. When representative_dataset
        is not provided, this should be a model trained with QAT.
      signature_keys: List of keys identifying SignatureDef containing inputs
        and outputs. Although it defaults to None, it must be provided,
        because every code path joins it with ','.
      tags: Set of tags identifying the MetaGraphDef within the SavedModel to
        analyze. Although it defaults to None, it must be provided, because
        every code path joins it with ','.
      output_directory: The path to save the output SavedModel (must be an
        empty directory). A new temporary directory is used when None.
      representative_dataset: a callable, invoked with no arguments, that
        returns a generator of samples. Each sample is either a dictionary in
        {input_name: input_tensor} format or a tuple of (signature key,
        dictionary in {input_name: input_tensor} format) that feeds
        calibration data for quantizing model. This should be provided when
        the model is not a QAT model.

    Returns:
      A SavedModel object with TF quantization applied.

    Raises:
      ValueError: when representative_dataset is not provided for non-QAT
        model, when a representative sample is malformed, or when no valid
        signature remains after fixing tensor names against a rebuilt graph.
    """
    is_qat_saved_model = _is_qat_saved_model(saved_model_path)
    signatures = _get_signatures_from_saved_model(saved_model_path,
                                                  signature_keys, tags)

    # Without a representative dataset there is nothing to calibrate with, so
    # only QAT models are accepted in that case.
    if representative_dataset is None and not is_qat_saved_model:
        raise ValueError(
            'When `representative_dataset` is not provided, the model should be '
            'trained with quantization-aware training (QAT).')

    if is_qat_saved_model:
        # QAT models are quantized directly.
        graph_def_serialized = (quantize_model_wrapper.quantize_qat_model(
            saved_model_path, ','.join(signature_keys), ','.join(tags)))
    else:
        # PTQ models are quantized after calibration. The intermediate
        # SavedModels are written to temporary directories that are always
        # removed afterwards (tempfile.mkdtemp never deletes them itself).
        import shutil  # pylint: disable=g-import-not-at-top

        temp_dirs = []
        try:
            graph_def_serialized = (
                quantize_model_wrapper.quantize_ptq_model_pre_calibration(
                    saved_model_path, ','.join(signature_keys), ','.join(tags)))

            graph_def = graph_pb2.GraphDef()
            graph_def.ParseFromString(graph_def_serialized)

            float_model_dir = tempfile.mkdtemp()
            temp_dirs.append(float_model_dir)
            v1_builder = builder.SavedModelBuilder(float_model_dir)

            with session.Session(graph=ops.Graph()) as sess:
                # Give every CustomAggregator a unique id. Its statistics are
                # looked up by this id after calibration.
                for function_def in graph_def.library.function:
                    for node_def in function_def.node_def:
                        if node_def.op == 'CustomAggregator':
                            node_def.attr['id'].s = uuid.uuid4().hex.encode(
                                'ascii')

                importer.import_graph_def(graph_def, name='')
                working_graph = ops.get_default_graph()
                graph_def = working_graph.as_graph_def()

                signatures = _fix_tensor_names(signatures, working_graph)
                if signatures is None:
                    raise ValueError(
                        "The input SavedModel doesn't contain a valid signature")

                # Saved under `tags` (not only SERVING) so the lookups below
                # that use `tags` can find this MetaGraphDef.
                v1_builder.add_meta_graph_and_variables(
                    sess, tags, signature_def_map=signatures)

            v1_builder.save()

            float_model = saved_model_load(float_model_dir)

            for sample in representative_dataset():
                # TODO(b/214311251): Add a test case with multiple signatures.
                if isinstance(sample, tuple):
                    # A tuple must be exactly (signature_key, input_dict). A
                    # shorter one would otherwise fail with an IndexError.
                    if len(sample) != 2 or not isinstance(sample[1], dict):
                        raise ValueError(
                            'You need to provide a dictionary with input '
                            'names and values in the second argument in the '
                            'tuple')
                    signature_key, input_data_map = sample
                elif isinstance(sample, dict):
                    if len(signature_keys) > 1:
                        raise ValueError(
                            'When the model has multiple signatures, you need '
                            'to provide a tuple with signature key and a '
                            'dictionary with input names and values')
                    signature_key = signature_keys[0]
                    input_data_map = sample
                else:
                    raise ValueError(
                        'You need to provide either a dictionary with input '
                        'names and values or a tuple with signature key and a '
                        'dictionary with input names and values')
                func = float_model.signatures[signature_key]
                func(**input_data_map)

            # Copy the recorded statistics into each aggregator's min/max
            # attributes and clear them from the calibrator.
            for function_def in graph_def.library.function:
                for node_def in function_def.node_def:
                    if node_def.op != 'CustomAggregator':
                        continue
                    node_id = node_def.attr['id'].s
                    try:
                        min_val = quantize_model_wrapper.get_min_from_calibrator(
                            node_id)
                        max_val = quantize_model_wrapper.get_max_from_calibrator(
                            node_id)
                        quantize_model_wrapper.clear_data_from_calibrator(
                            node_id)
                        node_def.attr['min'].f = float(min_val)
                        node_def.attr['max'].f = float(max_val)
                    except ValueError:
                        # The ids are ASCII hex assigned above. Decode so the
                        # message does not show a bytes repr (b'...').
                        warnings.warn('%s does not have min/max values.' %
                                      node_id.decode('utf-8'))

            calibrated_model_dir = tempfile.mkdtemp()
            temp_dirs.append(calibrated_model_dir)
            v1_builder = builder.SavedModelBuilder(calibrated_model_dir)

            with session.Session(graph=ops.Graph()) as sess:
                importer.import_graph_def(graph_def, name='')
                v1_builder.add_meta_graph_and_variables(
                    sess, tags, signature_def_map=signatures)

            v1_builder.save()
            signatures = _get_signatures_from_saved_model(calibrated_model_dir,
                                                          signature_keys, tags)

            graph_def_serialized = (
                quantize_model_wrapper.quantize_ptq_model_post_calibration(
                    calibrated_model_dir,
                    ','.join(signature_keys),
                    ','.join(tags),
                ))
        finally:
            # Best-effort cleanup: a failed removal must not mask the result
            # or the original error.
            for temp_dir in temp_dirs:
                shutil.rmtree(temp_dir, ignore_errors=True)

    graph_def = graph_pb2.GraphDef()
    graph_def.ParseFromString(graph_def_serialized)

    # The output directory is returned to the caller, so it is not cleaned up.
    if output_directory is None:
        output_directory = tempfile.mkdtemp()
    v1_builder = builder.SavedModelBuilder(output_directory)

    with session.Session(graph=ops.Graph()) as sess:
        importer.import_graph_def(graph_def, name='')
        working_graph = ops.get_default_graph()

        signatures = _fix_tensor_names(signatures, working_graph)
        if signatures is None:
            raise ValueError(
                "The input SavedModel doesn't contain a valid signature")

        v1_builder.add_meta_graph_and_variables(sess, [tag_constants.SERVING],
                                                signature_def_map=signatures)

    v1_builder.save()

    return saved_model_load(output_directory)
예제 #8
0
def _static_range_quantize(saved_model_path: str,
                           signature_keys=None,
                           tags=None,
                           output_directory=None,
                           representative_dataset=None):
  """Quantizes the given SavedModel via static range quantization.

  QAT models are quantized directly. Other models are quantized after
  calibration, in these steps:

    1. The pre-calibration graph, with uniquely identified `CustomAggregator`
       nodes, is saved to a temporary SavedModel.
    2. The representative dataset is run through it (eagerly or in graph
       mode, depending on `context.executing_eagerly()`).
    3. The recorded min/max values are written back into the aggregator
       nodes.
    4. The calibrated graph is saved to a second temporary SavedModel, which
       is then quantized.

  Both intermediate SavedModels are saved under `tags`, so the later lookups
  that use `tags` find their MetaGraphDef. Both temporary directories are
  deleted before returning.

  Args:
    saved_model_path: Path to the saved model. When representative_dataset is
      not provided, this should be a model trained with QAT.
    signature_keys: List of keys identifying SignatureDef containing inputs and
      outputs.
    tags: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze.
    output_directory: The path to save the output SavedModel (must be an empty
      directory). A new temporary directory is used when None.
    representative_dataset: a generator that returns a dictionary in
      {input_name: input_tensor} format or a tuple with signature key and a
      dictionary in {input_name: input_tensor} format that feeds calibration
      data for quantizing model. This should be provided when the model is not a
      QAT model.

  Returns:
    A SavedModel object with TF quantization applied.

  Raises:
    ValueError: when representative_dataset is not provided for non-QAT model,
      or when no valid signature remains after fixing tensor names against a
      rebuilt graph.
  """
  is_qat_saved_model = _is_qat_saved_model(saved_model_path)
  signatures = _get_signatures_from_saved_model(saved_model_path,
                                                signature_keys, tags)

  # Without a representative dataset there is nothing to calibrate with, so
  # only QAT models are accepted in that case.
  if representative_dataset is None and not is_qat_saved_model:
    raise ValueError(
        'When `representative_dataset` is not provided, the model should be '
        'trained with quantization-aware training (QAT).')

  if is_qat_saved_model:
    # QAT models are quantized directly.
    graph_def_serialized = (
        quantize_model_wrapper.quantize_qat_model(saved_model_path,
                                                  ','.join(signature_keys),
                                                  ','.join(tags)))
  else:
    # PTQ models are quantized after calibration. The intermediate SavedModels
    # are written to temporary directories that are always removed afterwards
    # (tempfile.mkdtemp never deletes them itself).
    import shutil  # pylint: disable=g-import-not-at-top

    temp_dirs = []
    try:
      graph_def_serialized = (
          quantize_model_wrapper.quantize_ptq_model_pre_calibration(
              saved_model_path, ','.join(signature_keys), ','.join(tags)))

      graph_def = graph_pb2.GraphDef()
      graph_def.ParseFromString(graph_def_serialized)

      float_model_dir = tempfile.mkdtemp()
      temp_dirs.append(float_model_dir)
      v1_builder = builder.SavedModelBuilder(float_model_dir)

      with session.Session(graph=ops.Graph()) as sess:
        # Give every CustomAggregator a unique id. Its statistics are looked
        # up by this id after calibration.
        for function_def in graph_def.library.function:
          for node_def in function_def.node_def:
            if node_def.op == 'CustomAggregator':
              node_def.attr['id'].s = uuid.uuid4().hex.encode('ascii')

        importer.import_graph_def(graph_def, name='')
        working_graph = ops.get_default_graph()
        graph_def = working_graph.as_graph_def()

        signatures = _fix_tensor_names(signatures, working_graph)
        if signatures is None:
          raise ValueError(
              "The input SavedModel doesn't contain a valid signature")

        # Saved under `tags` (not only SERVING) so the lookups below that use
        # `tags` can find this MetaGraphDef.
        v1_builder.add_meta_graph_and_variables(
            sess, tags, signature_def_map=signatures)

      v1_builder.save()

      float_model = saved_model_load(float_model_dir)

      # Uses the representative dataset to collect statistics for calibration.
      # Handles the graph mode execution separately in case TF2 is disabled or
      # eager execution is disabled.
      if context.executing_eagerly():
        _run_graph_for_calibration_eager_mode(float_model, signature_keys,
                                              representative_dataset)
      else:
        _run_graph_for_calibration_graph_mode(float_model, signature_keys,
                                              representative_dataset)

      # Copy the recorded statistics into each aggregator's min/max attributes
      # and clear them from the calibrator.
      for function_def in graph_def.library.function:
        for node_def in function_def.node_def:
          if node_def.op != 'CustomAggregator':
            continue
          node_id = node_def.attr['id'].s
          try:
            min_val = quantize_model_wrapper.get_min_from_calibrator(node_id)
            max_val = quantize_model_wrapper.get_max_from_calibrator(node_id)
            quantize_model_wrapper.clear_data_from_calibrator(node_id)
            node_def.attr['min'].f = float(min_val)
            node_def.attr['max'].f = float(max_val)
          except ValueError:
            warnings.warn(
                f'CustomAggregator id "{node_id.decode("utf-8")}" from '
                f'FunctionDef "{function_def.signature.name}" does not have '
                'min or max values. This function may not be quantized.')

      calibrated_model_dir = tempfile.mkdtemp()
      temp_dirs.append(calibrated_model_dir)
      v1_builder = builder.SavedModelBuilder(calibrated_model_dir)

      with session.Session(graph=ops.Graph()) as sess:
        importer.import_graph_def(graph_def, name='')
        v1_builder.add_meta_graph_and_variables(
            sess, tags, signature_def_map=signatures)

      v1_builder.save()
      signatures = _get_signatures_from_saved_model(calibrated_model_dir,
                                                    signature_keys, tags)

      graph_def_serialized = (
          quantize_model_wrapper.quantize_ptq_model_post_calibration(
              calibrated_model_dir,
              ','.join(signature_keys),
              ','.join(tags),
          ))
    finally:
      # Best-effort cleanup: a failed removal must not mask the result or the
      # original error.
      for temp_dir in temp_dirs:
        shutil.rmtree(temp_dir, ignore_errors=True)

  graph_def = graph_pb2.GraphDef()
  graph_def.ParseFromString(graph_def_serialized)

  # The output directory is returned to the caller, so it is not cleaned up.
  if output_directory is None:
    output_directory = tempfile.mkdtemp()
  v1_builder = builder.SavedModelBuilder(output_directory)

  with session.Session(graph=ops.Graph()) as sess:
    importer.import_graph_def(graph_def, name='')
    working_graph = ops.get_default_graph()

    signatures = _fix_tensor_names(signatures, working_graph)
    if signatures is None:
      raise ValueError("The input SavedModel doesn't contain a valid signature")

    v1_builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING], signature_def_map=signatures)

  v1_builder.save()

  return saved_model_load(output_directory)