def testGetsResultQuantizers_EmptyWhenFalse(self):
    """With output quantization disabled, no output quantizers are returned."""
    dense = self._simple_dense_layer()
    config = tflite_quantize_registry.TFLiteQuantizeConfig(
        [], [], False)

    result = config.get_output_quantizers(dense)

    self.assertEqual([], result)

  def testSetsQuantizeWeights_ErrorOnWrongShapeOfWeight(self):
    """Setting a kernel of the wrong shape ([1, 2] here) raises ValueError."""
    dense = self._simple_dense_layer()
    mismatched_kernel = K.variable(np.ones([1, 2]))

    config = tflite_quantize_registry.TFLiteQuantizeConfig(
        ['kernel'], ['activation'], False)

    with self.assertRaises(ValueError):
      config.set_quantize_weights(dense, [mismatched_kernel])

  def testGetsResultQuantizers_ReturnsQuantizer(self):
    """With output quantization enabled, exactly one quantizer is returned."""
    dense = self._simple_dense_layer()
    config = tflite_quantize_registry.TFLiteQuantizeConfig(
        [], [], True)

    quantizers = config.get_output_quantizers(dense)

    self.assertLen(quantizers, 1)
    self._assert_activation_quantizers(quantizers)

  def testSetsQuantizeActivations(self):
    """set_quantize_activations installs the given function on the layer."""
    layer = self._simple_dense_layer()
    quantize_activation = keras.activations.relu

    quantize_config = tflite_quantize_registry.TFLiteQuantizeConfig(
        ['kernel'], ['activation'], False)
    quantize_config.set_quantize_activations(layer, [quantize_activation])

    self.assertEqual(layer.activation, quantize_activation)

  def testSetsQuantizeWeights(self):
    """set_quantize_weights makes the layer's kernel match the given one.

    The replacement kernel is built with the same shape as the layer's
    existing kernel, so the call is expected to succeed (contrast with
    testSetsQuantizeWeights_ErrorOnWrongShapeOfWeight).
    """
    layer = self._simple_dense_layer()
    quantize_kernel = K.variable(np.ones(layer.kernel.shape.as_list()))

    quantize_config = tflite_quantize_registry.TFLiteQuantizeConfig(
        ['kernel'], ['activation'], False)
    quantize_config.set_quantize_weights(layer, [quantize_kernel])

    self._assert_kernel_equality(layer.kernel, quantize_kernel)

  def testGetsQuantizeActivationsAndQuantizers(self):
    """The configured activation is reported alongside activation quantizers."""
    dense = self._simple_dense_layer()

    config = tflite_quantize_registry.TFLiteQuantizeConfig(
        ['kernel'], ['activation'], False)
    pairs = config.get_activations_and_quantizers(dense)
    activations, quantizers = self._convert_list(pairs)

    self._assert_activation_quantizers(quantizers)
    self.assertEqual([dense.activation], activations)

  def testGetsQuantizeWeightsAndQuantizers(self):
    """The configured kernel weight is reported alongside weight quantizers."""
    dense = self._simple_dense_layer()

    config = tflite_quantize_registry.TFLiteQuantizeConfig(
        ['kernel'], ['activation'], False)
    pairs = config.get_weights_and_quantizers(dense)
    weights, quantizers = self._convert_list(pairs)

    self._assert_weight_quantizers(quantizers)
    self.assertEqual([dense.kernel], weights)

  def testSetsQuantizeActivations_ErrorOnWrongNumberOfActivations(self):
    """Too few or too many replacement activations raise ValueError."""
    dense = self._simple_dense_layer()
    relu = keras.activations.relu

    config = tflite_quantize_registry.TFLiteQuantizeConfig(
        ['kernel'], ['activation'], False)

    # Exactly one activation attribute is configured, so both an empty list
    # and a two-element list are the wrong length.
    for wrong_activations in ([], [relu, relu]):
      with self.assertRaises(ValueError):
        config.set_quantize_activations(dense, wrong_activations)

  def testSerialization(self):
    """The config serializes to the expected dict and round-trips equal."""
    original = tflite_quantize_registry.TFLiteQuantizeConfig(
        ['kernel'], ['activation'], False)

    serialized = serialize_keras_object(original)
    self.assertEqual(
        {
            'class_name': 'TFLiteQuantizeConfig',
            'config': {
                'weight_attrs': ['kernel'],
                'activation_attrs': ['activation'],
                'quantize_output': False,
            },
        },
        serialized)

    restored = deserialize_keras_object(
        serialized,
        module_objects=globals(),
        custom_objects=tflite_quantize_registry._types_dict())
    self.assertEqual(original, restored)