Example #1
  def test_ptq_model_with_wrong_tags_raises_error(self):
    input_saved_model_path = self.create_tempdir('input').full_path
    signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    save_tags = {tag_constants.TRAINING, tag_constants.GPU}

    input_placeholder = _create_and_save_tf1_conv_model(
        input_saved_model_path,
        signature_key,
        save_tags,
        input_key='input',
        output_key='output',
        use_variable=True)

    signature_keys = [signature_key]
    output_directory = self.create_tempdir().full_path

    quantization_options = quant_opts_pb2.QuantizationOptions(
        quantization_method=quant_opts_pb2.QuantizationMethod(
            experimental_method=_ExperimentalMethod.STATIC_RANGE))

    # Try to use a different set of tags to quantize.
    tags = {tag_constants.SERVING}
    data_gen = _create_data_generator(
        input_key='input', shape=input_placeholder.shape)
    with self.assertRaisesRegex(RuntimeError,
                                'Failed to retrieve MetaGraphDef'):
      quantize_model.quantize(
          input_saved_model_path,
          signature_keys,
          tags,
          output_directory,
          quantization_options,
          representative_dataset=data_gen)
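Note: the _create_data_generator helper is not shown in this listing. Given how it is called here, a minimal sketch might look like the following; the exact signature, sample count, and value range are assumptions:

def _create_data_generator(input_key, shape, num_examples=8):
  """Sketch: yields feed dicts keyed by input_key with the given shape."""
  for _ in range(num_examples):
    yield {input_key: random_ops.random_uniform(shape, minval=-1., maxval=1.)}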
Example #2
  def test_model_use_representative_samples_list(self):
    model = self.MatmulModel()
    input_savedmodel_dir = self.create_tempdir('input').full_path
    saved_model_save.save(model, input_savedmodel_dir)

    quantization_options = quant_opts_pb2.QuantizationOptions(
        quantization_method=quant_opts_pb2.QuantizationMethod(
            experimental_method=_ExperimentalMethod.STATIC_RANGE))
    output_savedmodel_dir = self.create_tempdir().full_path
    tags = {tag_constants.SERVING}

    representative_dataset = [{
        'input_tensor': random_ops.random_uniform(shape=(1, 4))
    } for _ in range(128)]

    converted_model = quantize_model.quantize(
        input_savedmodel_dir, ['serving_default'],
        output_directory=output_savedmodel_dir,
        quantization_options=quantization_options,
        representative_dataset=representative_dataset)

    self.assertIsNotNone(converted_model)
    self.assertEqual(
        list(converted_model.signatures._signatures.keys()),
        ['serving_default'])
    output_loader = saved_model_loader.SavedModelLoader(output_savedmodel_dir)
    output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
    # Model is quantized because the 128 representative samples above provided
    # data for calibration.
    self.assertTrue(_contains_quantized_function_call(output_meta_graphdef))
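The MatmulModel fixture (self.MatmulModel) is defined elsewhere in the test class. Based on the inline MatmulModel definitions in later examples, a minimal sketch is shown below; note that some tests (Examples #3 and #9) construct it with has_bias and activation_fn arguments:

class MatmulModel(module.Module):
  """Sketch of the MatmulModel fixture assumed by this example."""

  @def_function.function(input_signature=[
      tensor_spec.TensorSpec(shape=[1, 4], dtype=dtypes.float32)
  ])
  def matmul(self, input_tensor):
    filters = random_ops.random_uniform(shape=(4, 3), minval=-1., maxval=1.)
    return {'output': math_ops.matmul(input_tensor, filters)}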
Example #3
  def test_matmul_ptq_model(self, activation_fn, has_bias):
    model = self.MatmulModel(has_bias, activation_fn)
    input_saved_model_path = self.create_tempdir('input').full_path
    saved_model_save.save(model, input_saved_model_path)

    def data_gen():
      for _ in range(255):
        yield {
            'input_tensor':
                ops.convert_to_tensor(
                    np.random.uniform(low=0, high=5, size=(1, 4)).astype('f4')),
        }

    tags = [tag_constants.SERVING]
    output_directory = self.create_tempdir().full_path

    quantization_options = quant_opts_pb2.QuantizationOptions(
        quantization_method=quant_opts_pb2.QuantizationMethod(
            experimental_method=_ExperimentalMethod.STATIC_RANGE))

    converted_model = quantize_model.quantize(
        input_saved_model_path, ['serving_default'],
        tags,
        output_directory,
        quantization_options,
        representative_dataset=data_gen())
    self.assertIsNotNone(converted_model)
    self.assertEqual(
        list(converted_model.signatures._signatures.keys()),
        ['serving_default'])

    output_loader = saved_model_loader.SavedModelLoader(output_directory)
    output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
    self.assertTrue(_contains_quantized_function_call(output_meta_graphdef))
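_contains_quantized_function_call is another unshown helper. One plausible implementation scans the function library of the meta graph for functions the quantizer has rewritten; the 'quantized_' name prefix is an assumption:

def _contains_quantized_function_call(meta_graphdef):
  """Sketch: True if any library function looks like a quantized rewrite."""
  for func in meta_graphdef.graph_def.library.function:
    if func.signature.name.startswith('quantized_'):
      return True
  return False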
Example #4
  def test_ptq_model_with_tf1_saved_model_invalid_input_key_raises_value_error(
      self):
    input_saved_model_path = self.create_tempdir('input').full_path
    signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    tags = {tag_constants.SERVING}

    input_placeholder = _create_and_save_tf1_conv_model(
        input_saved_model_path,
        signature_key,
        tags,
        input_key='x',
        output_key='output',
        use_variable=False)

    signature_keys = [signature_key]
    output_directory = self.create_tempdir().full_path

    quantization_options = quant_opts_pb2.QuantizationOptions(
        quantization_method=quant_opts_pb2.QuantizationMethod(
            experimental_method=_ExperimentalMethod.STATIC_RANGE))

    # Representative generator function that yields with an invalid input key.
    invalid_data_gen = _create_data_generator(
        input_key='invalid_input_key', shape=input_placeholder.shape)

    with self.assertRaisesRegex(
        ValueError,
        'Failed to run graph for post-training quantization calibration'):
      quantize_model.quantize(
          input_saved_model_path,
          signature_keys,
          tags,
          output_directory,
          quantization_options,
          representative_dataset=invalid_data_gen)
Example #5
    def test_matmul_model(self):
        class MatmulModel(module.Module):
            @def_function.function(input_signature=[
                tensor_spec.TensorSpec(shape=[1, 4], dtype=dtypes.float32)
            ])
            def matmul(self, input_tensor):
                filters = np.random.uniform(low=-1.0, high=1.0,
                                            size=(4, 3)).astype('f4')
                out = math_ops.matmul(input_tensor, filters)
                return {'output': out}

        model = MatmulModel()
        input_saved_model_path = self.create_tempdir('input').full_path
        saved_model_save.save(model, input_saved_model_path)

        tags = [tag_constants.SERVING]
        output_directory = self.create_tempdir().full_path

        quantization_options = quant_opts_pb2.QuantizationOptions(
            quantization_method=quant_opts_pb2.QuantizationMethod(
                experimental_method=_ExperimentalMethod.DYNAMIC_RANGE))

        converted_model = quantize_model.quantize(input_saved_model_path,
                                                  ['serving_default'], tags,
                                                  output_directory,
                                                  quantization_options)
        self.assertIsNotNone(converted_model)
        self.assertEqual(list(converted_model.signatures._signatures.keys()),
                         ['serving_default'])

        output_loader = saved_model_loader.SavedModelLoader(output_directory)
        output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
        self.assertTrue(
            _contains_quantized_function_call(output_meta_graphdef))
Example #6
    def test_model_no_representative_sample_shows_warnings(self):
        class SimpleMatmulModel(module.Module):
            @def_function.function(input_signature=[
                tensor_spec.TensorSpec(shape=[1, 4], dtype=dtypes.float32)
            ])
            def matmul(self, input_tensor):
                filters = random_ops.random_uniform(shape=(4, 3),
                                                    minval=-1.,
                                                    maxval=1.)
                bias = random_ops.random_uniform(shape=(3, ),
                                                 minval=-1.,
                                                 maxval=1.)

                out = math_ops.matmul(input_tensor, filters)
                out = nn_ops.bias_add(out, bias)
                return {'output': out}

        model = SimpleMatmulModel()
        input_savedmodel_dir = self.create_tempdir('input').full_path
        output_savedmodel_dir = self.create_tempdir().full_path
        saved_model_save.save(model, input_savedmodel_dir)

        tags = [tag_constants.SERVING]
        quantization_options = quant_opts_pb2.QuantizationOptions(
            quantization_method=quant_opts_pb2.QuantizationMethod(
                experimental_method=_ExperimentalMethod.STATIC_RANGE))

        with warnings.catch_warnings(record=True) as warnings_list:
            converted_model = quantize_model.quantize(
                input_savedmodel_dir,
                ['serving_default'],
                tags,
                output_savedmodel_dir,
                quantization_options,
                # Put no sample into the representative dataset to make calibration
                # impossible.
                representative_dataset=lambda: [])

            self.assertNotEmpty(warnings_list)

            # Warning message should contain the function name.
            self.assertTrue(self._any_warning_contains('matmul',
                                                       warnings_list))
            self.assertTrue(
                self._any_warning_contains('does not have min or max values',
                                           warnings_list))

        self.assertIsNotNone(converted_model)
        self.assertEqual(list(converted_model.signatures._signatures.keys()),
                         ['serving_default'])
        output_loader = saved_model_loader.SavedModelLoader(
            output_savedmodel_dir)
        output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
        # Model is not quantized because there was no sample data for calibration.
        self.assertFalse(
            _contains_quantized_function_call(output_meta_graphdef))
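_any_warning_contains is presumably a small predicate over the recorded warnings; a sketch consistent with its call sites:

    def _any_warning_contains(self, substring, warnings_list):
        """Sketch: True if any caught warning's message contains substring."""
        return any(substring in str(w.message) for w in warnings_list)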
Example #7
    def test_method_unspecified_raises_value_error(self):
        model = self.SimpleModel()

        input_saved_model_path = self.create_tempdir('input').full_path
        saved_model_save.save(model, input_saved_model_path)

        options = quant_opts_pb2.QuantizationOptions(
            quantization_method=quant_opts_pb2.QuantizationMethod(
                method=_Method.METHOD_UNSPECIFIED))

        with self.assertRaises(ValueError):
            quantize_model.quantize(input_saved_model_path,
                                    quantization_options=options)
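The _Method and _ExperimentalMethod names used throughout this listing are presumably module-level aliases for the QuantizationMethod proto enums; Example #11 spells the full path out:

# Presumed module-level aliases matching the full path used in Example #11.
_Method = quant_opts_pb2.QuantizationMethod.Method
_ExperimentalMethod = quant_opts_pb2.QuantizationMethod.ExperimentalMethod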
Example #8
    def test_invalid_method_raises_value_error(self):
        model = self.SimpleModel()

        input_saved_model_path = self.create_tempdir('input').full_path
        saved_model_save.save(model, input_saved_model_path)

        # Set an invalid value of -1 to QuantizationMethod.method.
        options = quant_opts_pb2.QuantizationOptions(
            quantization_method=quant_opts_pb2.QuantizationMethod(method=-1))

        with self.assertRaises(ValueError):
            quantize_model.quantize(input_saved_model_path,
                                    quantization_options=options)
Example #9
    def test_matmul_ptq_model(self, activation_fn, has_bias):
        class MatmulModel(module.Module):
            @def_function.function(input_signature=[
                tensor_spec.TensorSpec(shape=[1, 4], dtype=dtypes.float32)
            ])
            def matmul(self, input_tensor):
                filters = np.random.uniform(low=-1.0, high=1.0,
                                            size=(4, 3)).astype('f4')
                bias = np.random.uniform(low=-1.0, high=1.0,
                                         size=(3, )).astype('f4')
                out = math_ops.matmul(input_tensor, filters)
                if has_bias:
                    out = nn_ops.bias_add(out, bias)
                if activation_fn is not None:
                    out = activation_fn(out)
                return {'output': out}

        model = MatmulModel()
        input_saved_model_path = self.create_tempdir('input').full_path
        saved_model_save.save(model, input_saved_model_path)

        def data_gen():
            for _ in range(255):
                yield {
                    'input_tensor':
                    ops.convert_to_tensor(
                        np.random.uniform(low=0, high=5,
                                          size=(1, 4)).astype('f4')),
                }

        tags = [tag_constants.SERVING]
        output_directory = self.create_tempdir().full_path

        quantization_options = quant_opts_pb2.QuantizationOptions(
            quantization_method=quant_opts_pb2.QuantizationMethod(
                experimental_method=_ExperimentalMethod.STATIC_RANGE))

        converted_model = quantize_model.quantize(
            input_saved_model_path, ['serving_default'],
            tags,
            output_directory,
            quantization_options,
            representative_dataset=data_gen)
        self.assertIsNotNone(converted_model)
        self.assertEqual(list(converted_model.signatures._signatures.keys()),
                         ['serving_default'])

        output_loader = saved_model_loader.SavedModelLoader(output_directory)
        output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
        self.assertTrue(
            _contains_quantized_function_call(output_meta_graphdef))
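Examples #3 and #9 take activation_fn and has_bias arguments, so they are presumably parameterized tests. A hypothetical decorator is sketched below; the actual parameter grid is not shown in this listing:

    @parameterized.parameters(
        (None, False),
        (None, True),
        (nn_ops.relu, True),
        (nn_ops.relu6, True),
    )
    def test_matmul_ptq_model(self, activation_fn, has_bias):
        ...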
Example #10
  def test_conv_model(self):

    class ConvModel(module.Module):

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[1, 3, 4, 3], dtype=dtypes.float32)
      ])
      def conv(self, input_tensor):
        filters = np.random.uniform(
            low=-10, high=10, size=(2, 3, 3, 2)).astype('f4')
        bias = np.random.uniform(low=0, high=10, size=(2)).astype('f4')
        out = nn_ops.conv2d(
            input_tensor,
            filters,
            strides=[1, 1, 2, 1],
            dilations=[1, 1, 1, 1],
            padding='SAME',
            data_format='NHWC')
        out = nn_ops.bias_add(out, bias, data_format='NHWC')
        out = nn_ops.relu6(out)
        return {'output': out}

    model = ConvModel()
    input_saved_model_path = self.create_tempdir('input').full_path
    saved_model_save.save(model, input_saved_model_path)

    tags = [tag_constants.SERVING]
    output_directory = self.create_tempdir().full_path

    quantization_options = quant_opts_pb2.QuantizationOptions(
        quantization_method=quant_opts_pb2.QuantizationMethod(
            experimental_method=_ExperimentalMethod.DYNAMIC_RANGE))

    converted_model = quantize_model.quantize(input_saved_model_path,
                                              ['serving_default'], tags,
                                              output_directory,
                                              quantization_options)

    self.assertIsNotNone(converted_model)
    self.assertEqual(
        list(converted_model.signatures._signatures.keys()),
        ['serving_default'])

    output_loader = saved_model_loader.SavedModelLoader(output_directory)
    output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
    # Currently conv is not supported.
    self.assertFalse(_contains_quantized_function_call(output_meta_graphdef))
Example #11
    def _convert_with_calibration(self):
        class ModelWithAdd(autotrackable.AutoTrackable):
            """Basic model with addition."""
            @def_function.function(input_signature=[
                tensor_spec.TensorSpec(shape=[10],
                                       dtype=dtypes.float32,
                                       name='x'),
                tensor_spec.TensorSpec(shape=[10],
                                       dtype=dtypes.float32,
                                       name='y')
            ])
            def add(self, x, y):
                res = math_ops.add(x, y)
                return {'output': res}

        def data_gen():
            for _ in range(255):
                yield {
                    'x':
                    ops.convert_to_tensor(
                        np.random.uniform(size=(10)).astype('f4')),
                    'y':
                    ops.convert_to_tensor(
                        np.random.uniform(size=(10)).astype('f4'))
                }

        root = ModelWithAdd()

        temp_path = self.create_tempdir().full_path
        saved_model_save.save(root,
                              temp_path,
                              signatures=root.add.get_concrete_function())

        quantization_options = quant_opts_pb2.QuantizationOptions(
            quantization_method=quant_opts_pb2.QuantizationMethod(
                experimental_method=quant_opts_pb2.QuantizationMethod.
                ExperimentalMethod.STATIC_RANGE))

        model = quantize_model.quantize(
            temp_path, ['serving_default'], [tag_constants.SERVING],
            quantization_options=quantization_options,
            representative_dataset=data_gen())
        return model
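Unlike the other examples, this helper returns the converted model rather than asserting on it. A hypothetical call site:

    def test_add_model_with_calibration(self):
        # Hypothetical caller: quantize with calibration, then inspect signatures.
        model = self._convert_with_calibration()
        self.assertIsNotNone(model)
        self.assertIn('serving_default', model.signatures._signatures)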
Example #12
  def test_ptq_model_with_non_default_tags(self):
    input_saved_model_path = self.create_tempdir('input').full_path
    signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    # Use a different set of tags other than {"serve"}.
    tags = {tag_constants.TRAINING, tag_constants.GPU}

    # Non-default tags are usually used when saving multiple metagraphs in TF1.
    input_placeholder = _create_and_save_tf1_conv_model(
        input_saved_model_path,
        signature_key,
        tags,
        input_key='input',
        output_key='output',
        use_variable=True)

    signature_keys = [signature_key]
    output_directory = self.create_tempdir().full_path

    quantization_options = quant_opts_pb2.QuantizationOptions(
        quantization_method=quant_opts_pb2.QuantizationMethod(
            experimental_method=_ExperimentalMethod.STATIC_RANGE))

    data_gen = _create_data_generator(
        input_key='input', shape=input_placeholder.shape)

    converted_model = quantize_model.quantize(
        input_saved_model_path,
        signature_keys,
        tags,
        output_directory,
        quantization_options,
        representative_dataset=data_gen)

    self.assertIsNotNone(converted_model)
    self.assertEqual(
        list(converted_model.signatures._signatures.keys()), signature_keys)

    output_loader = saved_model_loader.SavedModelLoader(output_directory)
    output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
    self.assertTrue(_contains_quantized_function_call(output_meta_graphdef))
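_create_and_save_tf1_conv_model is the remaining unshown helper. Based on its call sites and the conv examples above, a sketch follows; the shapes, op parameters, and builder-based save path are all assumptions, and the imports (session, array_ops, variables, saved_model_builder, signature_def_utils_impl) are presumed available:

def _create_and_save_tf1_conv_model(saved_model_path, signature_key, tags,
                                    input_key, output_key, use_variable=False):
  """Sketch: builds a TF1 conv graph and saves it under the given tags."""
  with ops.Graph().as_default(), session.Session() as sess:
    in_placeholder = array_ops.placeholder(
        dtypes.float32, shape=[1, 3, 4, 3], name=input_key)
    filters = random_ops.random_uniform(
        shape=(2, 3, 3, 2), minval=-1., maxval=1.)
    if use_variable:
      filters = variables.Variable(filters)
      sess.run(variables.global_variables_initializer())
    output = nn_ops.conv2d(
        in_placeholder, filters, strides=[1, 1, 2, 1], padding='SAME')
    builder = saved_model_builder.SavedModelBuilder(saved_model_path)
    builder.add_meta_graph_and_variables(
        sess,
        tags,
        signature_def_map={
            signature_key:
                signature_def_utils_impl.predict_signature_def(
                    inputs={input_key: in_placeholder},
                    outputs={output_key: output}),
        })
    builder.save()
  return in_placeholder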
Example #13
  def test_model_no_representative_sample_shows_warnings(self):
    model = self.MatmulModel()
    input_savedmodel_dir = self.create_tempdir('input').full_path
    output_savedmodel_dir = self.create_tempdir().full_path
    saved_model_save.save(model, input_savedmodel_dir)

    tags = [tag_constants.SERVING]
    quantization_options = quant_opts_pb2.QuantizationOptions(
        quantization_method=quant_opts_pb2.QuantizationMethod(
            experimental_method=_ExperimentalMethod.STATIC_RANGE))

    with warnings.catch_warnings(record=True) as warnings_list:
      converted_model = quantize_model.quantize(
          input_savedmodel_dir,
          ['serving_default'],
          tags,
          output_savedmodel_dir,
          quantization_options,
          # Put no sample into the representative dataset to make calibration
          # impossible.
          representative_dataset=[])

      self.assertNotEmpty(warnings_list)

      # Warning message should contain the function name.
      self.assertTrue(self._any_warning_contains('matmul', warnings_list))
      self.assertTrue(
          self._any_warning_contains('does not have min or max values',
                                     warnings_list))

    self.assertIsNotNone(converted_model)
    self.assertEqual(
        list(converted_model.signatures._signatures.keys()),
        ['serving_default'])
    output_loader = saved_model_loader.SavedModelLoader(output_savedmodel_dir)
    output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
    # Model is not quantized because there was no sample data for calibration.
    self.assertFalse(_contains_quantized_function_call(output_meta_graphdef))
Example #14
  def test_ptq_model_with_tf1_saved_model(self):
    input_saved_model_path = self.create_tempdir('input').full_path
    tags = {tag_constants.SERVING}
    signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY

    input_placeholder = _create_and_save_tf1_conv_model(
        input_saved_model_path,
        signature_key,
        tags,
        input_key='p',
        output_key='output',
        use_variable=False)

    signature_keys = [signature_key]
    output_directory = self.create_tempdir().full_path

    quantization_options = quant_opts_pb2.QuantizationOptions(
        quantization_method=quant_opts_pb2.QuantizationMethod(
            experimental_method=_ExperimentalMethod.STATIC_RANGE))

    data_gen = _create_data_generator(
        input_key='p', shape=input_placeholder.shape)

    converted_model = quantize_model.quantize(
        input_saved_model_path,
        signature_keys,
        tags,
        output_directory,
        quantization_options,
        representative_dataset=data_gen)

    self.assertIsNotNone(converted_model)
    self.assertEqual(
        list(converted_model.signatures._signatures.keys()), signature_keys)

    output_loader = saved_model_loader.SavedModelLoader(output_directory)
    output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
    self.assertTrue(_contains_quantized_function_call(output_meta_graphdef))
Example #15
    def test_model_with_uncalibrated_subgraph(self):
        class IfModel(module.Module):
            @def_function.function(input_signature=[
                tensor_spec.TensorSpec(shape=[1, 4], dtype=dtypes.float32)
            ])
            def model_fn(self, x):
                if math_ops.reduce_sum(x) > 10.0:
                    filters = np.random.uniform(low=-1.0,
                                                high=1.0,
                                                size=(4, 3)).astype('f4')
                    bias = np.random.uniform(low=-1.0, high=1.0,
                                             size=(3, )).astype('f4')
                    out = math_ops.matmul(x, filters)
                    out = nn_ops.bias_add(out, bias)
                    return {'output': out}

                filters = np.random.uniform(low=-1.0, high=1.0,
                                            size=(4, 3)).astype('f4')
                bias = np.random.uniform(low=-1.0, high=1.0,
                                         size=(3, )).astype('f4')
                out = math_ops.matmul(x, filters)
                out = nn_ops.bias_add(out, bias)
                return {'output': out}

        model = IfModel()
        input_saved_model_path = self.create_tempdir('input').full_path
        saved_model_save.save(model, input_saved_model_path)

        def data_gen():
            for _ in range(10):
                yield {
                    'x':
                    ops.convert_to_tensor(
                        np.random.uniform(low=0.0, high=1.0,
                                          size=(1, 4)).astype('f4')),
                }

        tags = [tag_constants.SERVING]
        output_directory = self.create_tempdir().full_path

        quantization_options = quant_opts_pb2.QuantizationOptions(
            quantization_method=quant_opts_pb2.QuantizationMethod(
                experimental_method=_ExperimentalMethod.STATIC_RANGE))

        with warnings.catch_warnings(record=True) as warnings_list:
            converted_model = quantize_model.quantize(
                input_saved_model_path, ['serving_default'],
                tags,
                output_directory,
                quantization_options,
                representative_dataset=data_gen)

            self.assertNotEmpty(warnings_list)

            # Warning message should contain the function name. The uncalibrated path
            # is when the condition is true, so 'cond_true' function must be part of
            # the warning message.
            self.assertTrue(
                self._any_warning_contains('cond_true', warnings_list))
            self.assertFalse(
                self._any_warning_contains('cond_false', warnings_list))
            self.assertTrue(
                self._any_warning_contains('does not have min or max values',
                                           warnings_list))

        self.assertIsNotNone(converted_model)
        self.assertEqual(list(converted_model.signatures._signatures.keys()),
                         ['serving_default'])
        output_loader = saved_model_loader.SavedModelLoader(output_directory)
        output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
        self.assertTrue(
            _contains_quantized_function_call(output_meta_graphdef))
Example #16
    def test_depthwise_conv_ptq_model(self, activation_fn, has_bias, has_bn):
        class DepthwiseConvModel(module.Module):
            @def_function.function(input_signature=[
                tensor_spec.TensorSpec(shape=[1, 3, 4, 3],
                                       dtype=dtypes.float32)
            ])
            def conv(self, input_tensor):
                filters = np.random.uniform(low=-10,
                                            high=10,
                                            size=(2, 3, 3, 1)).astype('f4')
                bias = np.random.uniform(low=0, high=10, size=(3)).astype('f4')
                scale, offset = [1.0, 1.0, 1.0], [0.5, 0.5, 0.5]
                mean, variance = scale, offset
                out = nn_ops.depthwise_conv2d_native(input_tensor,
                                                     filters,
                                                     strides=[1, 2, 2, 1],
                                                     dilations=[1, 1, 1, 1],
                                                     padding='SAME',
                                                     data_format='NHWC')
                if has_bias:
                    out = nn_ops.bias_add(out, bias)
                if has_bn:
                    # Fusing is supported for non-training case.
                    out, _, _, _, _, _ = nn_ops.fused_batch_norm_v3(
                        out, scale, offset, mean, variance, is_training=False)
                if activation_fn is not None:
                    out = activation_fn(out)
                return {'output': out}

        model = DepthwiseConvModel()
        input_saved_model_path = self.create_tempdir('input').full_path
        saved_model_save.save(model, input_saved_model_path)

        def data_gen():
            for _ in range(255):
                yield {
                    'input_tensor':
                    ops.convert_to_tensor(
                        np.random.uniform(low=0, high=150,
                                          size=(1, 3, 4, 3)).astype('f4')),
                }

        tags = [tag_constants.SERVING]
        output_directory = self.create_tempdir().full_path

        quantization_options = quant_opts_pb2.QuantizationOptions(
            quantization_method=quant_opts_pb2.QuantizationMethod(
                experimental_method=_ExperimentalMethod.STATIC_RANGE))

        converted_model = quantize_model.quantize(
            input_saved_model_path, ['serving_default'],
            tags,
            output_directory,
            quantization_options,
            representative_dataset=data_gen)
        self.assertIsNotNone(converted_model)
        self.assertEqual(list(converted_model.signatures._signatures.keys()),
                         ['serving_default'])

        output_loader = saved_model_loader.SavedModelLoader(output_directory)
        output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
        self.assertTrue(
            _contains_quantized_function_call(output_meta_graphdef))
        self.assertFalse(_contains_op(output_meta_graphdef,
                                      'FusedBatchNormV3'))
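_contains_op presumably checks for a node with the given op type anywhere in the graph, including the function library; a sketch:

def _contains_op(meta_graphdef, op_name):
  """Sketch: True if any node in the main graph or a library function has op_name."""
  if any(node.op == op_name for node in meta_graphdef.graph_def.node):
    return True
  return any(
      node.op == op_name
      for func in meta_graphdef.graph_def.library.function
      for node in func.node_def)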
Example #17
    def test_ptq_model_with_variable(self):
        class ConvModelWithVariable(module.Module):
            """A simple model that performs a single convolution to the input tensor.

            It keeps the filter as a tf.Variable.
            """
            def __init__(self):
                self.filters = variables.Variable(
                    random_ops.random_uniform(shape=(2, 3, 3, 2),
                                              minval=-1.,
                                              maxval=1.))

            @def_function.function(input_signature=[
                tensor_spec.TensorSpec(name='input',
                                       shape=(1, 3, 4, 3),
                                       dtype=dtypes.float32),
            ])
            def __call__(self, x):
                out = nn_ops.conv2d(x,
                                    self.filters,
                                    strides=[1, 1, 2, 1],
                                    dilations=[1, 1, 1, 1],
                                    padding='SAME',
                                    data_format='NHWC')
                return {'output': out}

        def gen_data():
            for _ in range(255):
                yield {
                    'input':
                    random_ops.random_uniform(shape=(1, 3, 4, 3),
                                              minval=0,
                                              maxval=150)
                }

        model = ConvModelWithVariable()
        input_saved_model_path = self.create_tempdir('input').full_path
        saved_model_save.save(model, input_saved_model_path)

        signature_keys = [
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        ]
        tags = {tag_constants.SERVING}
        output_directory = self.create_tempdir().full_path

        quantization_options = quant_opts_pb2.QuantizationOptions(
            quantization_method=quant_opts_pb2.QuantizationMethod(
                experimental_method=_ExperimentalMethod.STATIC_RANGE))

        converted_model = quantize_model.quantize(
            input_saved_model_path,
            signature_keys,
            tags,
            output_directory,
            quantization_options,
            representative_dataset=gen_data)

        self.assertIsNotNone(converted_model)
        self.assertEqual(list(converted_model.signatures._signatures.keys()),
                         signature_keys)

        output_loader = saved_model_loader.SavedModelLoader(output_directory)
        output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
        self.assertTrue(
            _contains_quantized_function_call(output_meta_graphdef))
Example #18
    def test_qat_conv_model(self, activation_fn, has_bias):
        class ConvModel(module.Module):
            @def_function.function(input_signature=[
                tensor_spec.TensorSpec(name='input',
                                       shape=[1, 3, 4, 3],
                                       dtype=dtypes.float32),
                tensor_spec.TensorSpec(name='filter',
                                       shape=[2, 3, 3, 2],
                                       dtype=dtypes.float32),
            ])
            def conv(self, input_tensor, filter_tensor):
                q_input = array_ops.fake_quant_with_min_max_args(
                    input_tensor,
                    min=-0.1,
                    max=0.2,
                    num_bits=8,
                    narrow_range=False)
                q_filters = array_ops.fake_quant_with_min_max_args(
                    filter_tensor,
                    min=-1.0,
                    max=2.0,
                    num_bits=8,
                    narrow_range=False)
                bias = array_ops.constant([0, 0], dtype=dtypes.float32)
                out = nn_ops.conv2d(q_input,
                                    q_filters,
                                    strides=[1, 1, 2, 1],
                                    dilations=[1, 1, 1, 1],
                                    padding='SAME',
                                    data_format='NHWC')
                if has_bias:
                    out = nn_ops.bias_add(out, bias, data_format='NHWC')
                if activation_fn is not None:
                    out = activation_fn(out)
                q_out = array_ops.fake_quant_with_min_max_args(
                    out, min=-0.3, max=0.4, num_bits=8, narrow_range=False)
                return {'output': q_out}

        model = ConvModel()
        input_saved_model_path = self.create_tempdir('input').full_path
        saved_model_save.save(model, input_saved_model_path)

        signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        tags = [tag_constants.SERVING]
        output_directory = self.create_tempdir().full_path

        quantization_options = quant_opts_pb2.QuantizationOptions(
            quantization_method=quant_opts_pb2.QuantizationMethod(
                experimental_method=_ExperimentalMethod.STATIC_RANGE))

        converted_model = quantize_model.quantize(input_saved_model_path,
                                                  [signature_key], tags,
                                                  output_directory,
                                                  quantization_options)
        self.assertIsNotNone(converted_model)
        self.assertEqual(list(converted_model.signatures._signatures.keys()),
                         [signature_key])

        input_data = np.random.uniform(low=-0.1, high=0.2,
                                       size=(1, 3, 4, 3)).astype('f4')
        filter_data = np.random.uniform(low=-0.5, high=0.5,
                                        size=(2, 3, 3, 2)).astype('f4')

        expected_outputs = model.conv(input_data, filter_data)
        got_outputs = converted_model.signatures[signature_key](
            input=ops.convert_to_tensor(input_data),
            filter=ops.convert_to_tensor(filter_data))
        # TODO(b/215633216): Check if the accuracy is acceptable.
        self.assertAllClose(expected_outputs, got_outputs, atol=0.01)

        output_loader = saved_model_loader.SavedModelLoader(output_directory)
        output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
        self.assertTrue(
            _contains_quantized_function_call(output_meta_graphdef))