Beispiel #1
0
    def testGetsResultQuantizers_EmptyWhenFalse(self):
        """With quantize_output=False the config yields no output quantizers."""
        dense = self._simple_dense_layer()
        config = default_8bit_quantize_registry.Default8BitQuantizeConfig(
            [], [], False)

        self.assertEqual([], config.get_output_quantizers(dense))
Beispiel #2
0
    def testSetsQuantizeWeights_ErrorOnWrongShapeOfWeight(self):
        """Supplying a kernel of the wrong shape must raise ValueError."""
        dense = self._simple_dense_layer()
        # A (1, 2) variable; the test relies on this differing from the helper
        # layer's kernel shape (helper not visible here).
        mismatched_kernel = K.variable(np.ones([1, 2]))

        config = default_8bit_quantize_registry.Default8BitQuantizeConfig(
            ['kernel'], ['activation'], False)

        with self.assertRaises(ValueError):
            config.set_quantize_weights(dense, [mismatched_kernel])
Beispiel #3
0
    def testGetsResultQuantizers_ReturnsQuantizer(self):
        """With quantize_output=True the config yields exactly one output quantizer."""
        dense = self._simple_dense_layer()
        config = default_8bit_quantize_registry.Default8BitQuantizeConfig(
            [], [], True)

        quantizers = config.get_output_quantizers(dense)
        self.assertLen(quantizers, 1)
        # The output quantizer is validated with the same helper used for
        # activation quantizers.
        self._assert_activation_quantizers(quantizers)
Beispiel #4
0
    def testSetsQuantizeActivations(self):
        """set_quantize_activations installs the supplied activation on the layer."""
        dense = self._simple_dense_layer()
        replacement = keras.activations.relu

        config = default_8bit_quantize_registry.Default8BitQuantizeConfig(
            ['kernel'], ['activation'], False)
        config.set_quantize_activations(dense, [replacement])

        self.assertEqual(dense.activation, replacement)
Beispiel #5
0
    def testSetsQuantizeWeights(self):
        """set_quantize_weights replaces the kernel with a same-shaped variable."""
        dense = self._simple_dense_layer()
        kernel_shape = dense.kernel.shape.as_list()
        replacement_kernel = K.variable(np.ones(kernel_shape))

        config = default_8bit_quantize_registry.Default8BitQuantizeConfig(
            ['kernel'], ['activation'], False)
        config.set_quantize_weights(dense, [replacement_kernel])

        self._assert_kernel_equality(dense.kernel, replacement_kernel)
Beispiel #6
0
    def testGetsQuantizeActivationsAndQuantizers(self):
        """The configured activation attribute is returned with its quantizer."""
        dense = self._simple_dense_layer()

        config = default_8bit_quantize_registry.Default8BitQuantizeConfig(
            ['kernel'], ['activation'], False)
        raw_result = config.get_activations_and_quantizers(dense)
        # _convert_list (defined elsewhere) splits the result into two lists.
        activations, quantizers = self._convert_list(raw_result)

        self._assert_activation_quantizers(quantizers)
        self.assertEqual([dense.activation], activations)
Beispiel #7
0
    def testGetsQuantizeWeightsAndQuantizers(self):
        """The configured kernel attribute is returned with its quantizer."""
        dense = self._simple_dense_layer()

        config = default_8bit_quantize_registry.Default8BitQuantizeConfig(
            ['kernel'], ['activation'], False)
        raw_result = config.get_weights_and_quantizers(dense)
        # _convert_list (defined elsewhere) splits the result into two lists.
        weights, quantizers = self._convert_list(raw_result)

        self._assert_weight_quantizers(quantizers)
        self.assertEqual([dense.kernel], weights)
Beispiel #8
0
    def testSetsQuantizeActivations_ErrorOnWrongNumberOfActivations(self):
        """Too few or too many activations must each raise ValueError."""
        dense = self._simple_dense_layer()
        relu = keras.activations.relu

        config = default_8bit_quantize_registry.Default8BitQuantizeConfig(
            ['kernel'], ['activation'], False)

        # One activation attribute is configured, so both an empty list and a
        # two-element list are the wrong length.
        for bad_activations in ([], [relu, relu]):
            with self.assertRaises(ValueError):
                config.set_quantize_activations(dense, bad_activations)
Beispiel #9
0
    def testSerialization(self):
        """The config serializes to the expected dict and round-trips back."""
        config = default_8bit_quantize_registry.Default8BitQuantizeConfig(
            ['kernel'], ['activation'], False)

        serialized = serialize_keras_object(config)
        self.assertEqual(
            {
                'class_name': 'Default8BitQuantizeConfig',
                'config': {
                    'weight_attrs': ['kernel'],
                    'activation_attrs': ['activation'],
                    'quantize_output': False,
                },
            },
            serialized)

        # Deserialize using this module's globals plus the registry's own
        # type mapping so the custom class can be resolved by name.
        restored = deserialize_keras_object(
            serialized,
            module_objects=globals(),
            custom_objects=default_8bit_quantize_registry._types_dict())

        self.assertEqual(config, restored)