예제 #1
0
  def testGetsResultQuantizers_EmptyWhenFalse(self):
    # With False as the provider's third constructor argument, asking for the
    # layer's output quantizers is expected to yield an empty list.
    dense = self._simple_dense_layer()
    provider = tflite_quantize_registry.TFLiteQuantizeProvider([], [], False)

    self.assertEqual([], provider.get_output_quantizers(dense))
예제 #2
0
  def testGetsResultQuantizers_ReturnsQuantizer(self):
    # With True as the provider's third constructor argument, exactly one
    # output quantizer is expected, and it must pass the activation-quantizer
    # checks of this test class.
    dense = self._simple_dense_layer()
    provider = tflite_quantize_registry.TFLiteQuantizeProvider([], [], True)

    quantizers = provider.get_output_quantizers(dense)

    self.assertLen(quantizers, 1)
    self._assert_activation_quantizers(quantizers)
예제 #3
0
  def testSetsQuantizeWeights_ErrorOnWrongShapeOfWeight(self):
    # Handing the provider a replacement kernel built from a fixed [1, 2]
    # array of ones is expected to be rejected with a ValueError.
    dense = self._simple_dense_layer()
    bad_kernel = K.variable(np.ones([1, 2]))

    provider = tflite_quantize_registry.TFLiteQuantizeProvider(
        ['kernel'], ['activation'], False)

    self.assertRaises(
        ValueError, provider.set_quantize_weights, dense, [bad_kernel])
예제 #4
0
  def testSetsQuantizeActivations(self):
    # After set_quantize_activations, the layer's `activation` attribute must
    # compare equal to the activation that was passed in.
    dense = self._simple_dense_layer()
    relu_fn = keras.activations.relu

    provider = tflite_quantize_registry.TFLiteQuantizeProvider(
        ['kernel'], ['activation'], False)
    new_activations = [relu_fn]
    provider.set_quantize_activations(dense, new_activations)

    self.assertEqual(dense.activation, relu_fn)
예제 #5
0
  def testSetsQuantizeWeights(self):
    # After set_quantize_weights, the layer's `kernel` attribute must compare
    # equal to the replacement variable, which is built with the same shape
    # as the layer's existing kernel.
    dense = self._simple_dense_layer()
    kernel_shape = dense.kernel.shape.as_list()
    new_kernel = K.variable(np.ones(kernel_shape))

    provider = tflite_quantize_registry.TFLiteQuantizeProvider(
        ['kernel'], ['activation'], False)
    provider.set_quantize_weights(dense, [new_kernel])

    self.assertEqual(dense.kernel, new_kernel)
예제 #6
0
  def testGetsQuantizeActivationsAndQuantizers(self):
    # The provider's activations-and-quantizers result, once passed through
    # the _convert_list helper, must contain the layer's own activation and
    # quantizers that pass the activation-quantizer checks.
    dense = self._simple_dense_layer()

    provider = tflite_quantize_registry.TFLiteQuantizeProvider(
        ['kernel'], ['activation'], False)
    provider_result = provider.get_activations_and_quantizers(dense)
    found_activations, found_quantizers = self._convert_list(provider_result)

    self._assert_activation_quantizers(found_quantizers)
    self.assertEqual([dense.activation], found_activations)
예제 #7
0
  def testGetsQuantizeWeightsAndQuantizers(self):
    # The provider's weights-and-quantizers result, once passed through the
    # _convert_list helper, must contain the layer's kernel and quantizers
    # that pass the weight-quantizer checks.
    dense = self._simple_dense_layer()

    provider = tflite_quantize_registry.TFLiteQuantizeProvider(
        ['kernel'], ['activation'], False)
    provider_result = provider.get_weights_and_quantizers(dense)
    found_weights, found_quantizers = self._convert_list(provider_result)

    self._assert_weight_quantizers(found_quantizers)
    self.assertEqual([dense.kernel], found_weights)
예제 #8
0
  def testSetsQuantizeActivations_ErrorOnWrongNumberOfActivations(self):
    # The provider is configured with a single activation attribute, so
    # passing zero or two activations must be rejected with a ValueError.
    layer = self._simple_dense_layer()
    quantize_activation = keras.activations.relu

    quantize_provider = tflite_quantize_registry.TFLiteQuantizeProvider(
        ['kernel'], ['activation'], False)

    # Too few activations.
    with self.assertRaises(ValueError):
      quantize_provider.set_quantize_activations(layer, [])

    # Too many activations.
    with self.assertRaises(ValueError):
      quantize_provider.set_quantize_activations(
          layer, [quantize_activation, quantize_activation])

  def testSetsQuantizeWeights_ErrorOnWrongNumberOfWeights(self):
    # The provider is configured with a single weight attribute, so passing
    # zero or two weights must be rejected with a ValueError.
    #
    # This test is a sibling method rather than a function nested inside the
    # activations test: a nested def is never called and never discovered by
    # the test runner, so its assertions would silently not run.
    layer = self._simple_dense_layer()
    quantize_kernel = K.variable(np.ones(layer.kernel.shape.as_list()))

    # Pass the third constructor argument explicitly, as every other test in
    # this file does, so the ValueError checks below exercise
    # set_quantize_weights rather than tripping on provider construction.
    quantize_provider = tflite_quantize_registry.TFLiteQuantizeProvider(
        ['kernel'], ['activation'], False)

    # Too few weights.
    with self.assertRaises(ValueError):
      quantize_provider.set_quantize_weights(layer, [])

    # Too many weights.
    with self.assertRaises(ValueError):
      quantize_provider.set_quantize_weights(
          layer, [quantize_kernel, quantize_kernel])
예제 #10
0
  def testSerialization(self):
    # Round-trips a provider through Keras object serialization: the
    # serialized form must match the expected config dict, and deserializing
    # it must give back an object equal to the original provider.
    provider = tflite_quantize_registry.TFLiteQuantizeProvider(
        ['kernel'], ['activation'], False)

    inner_config = {
        'weight_attrs': ['kernel'],
        'activation_attrs': ['activation'],
        'quantize_output': False
    }
    expected_config = {
        'class_name': 'TFLiteQuantizeProvider',
        'config': inner_config
    }
    serialized = serialize_keras_object(provider)

    self.assertEqual(expected_config, serialized)

    # Deserialization resolves names through this module's globals plus the
    # registry's own type dictionary.
    restored = deserialize_keras_object(
        serialized,
        module_objects=globals(),
        custom_objects=tflite_quantize_registry._types_dict())

    self.assertEqual(provider, restored)