Example #1
  def testDoesNotQuantizeNoOpActivation(self):
    layer = self.TestLayer()
    layer.activation = QuantizeAwareActivation(
        quantize_aware_activation.NoOpActivation(), self.quantizer, 0, layer)

    model = keras.Sequential([layer])

    x = np.array([[-2.0, -1.0, 1.0, 2.0]])
    self.assertAllClose(x, model.predict(x))
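
Example #1 verifies that a `NoOpActivation` wrapped in `QuantizeAwareActivation` passes values through unchanged. For reference, a minimal sketch of what such a pass-through activation could look like; the actual `quantize_aware_activation.NoOpActivation` may differ in details such as serialization and equality:

class NoOpActivation(object):
  """Pass-through stand-in for an activation; applies no transformation."""

  def __call__(self, x):
    # Return the input unchanged, so no quantization is applied around it.
    return x

  def get_config(self):
    # No parameters to serialize.
    return {}

  def __eq__(self, other):
    return isinstance(other, NoOpActivation)

  def __ne__(self, other):
    return not self.__eq__(other)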
Example #2
  def _replace(self, bn_layer_node, conv_layer_node):
    if _has_custom_quantize_config(bn_layer_node, conv_layer_node):
      return bn_layer_node

    conv_layer_node.layer['config']['activation'] = \
      keras.activations.serialize(quantize_aware_activation.NoOpActivation())
    bn_layer_node.metadata['quantize_config'] = \
      default_8bit_quantize_configs.Default8BitOutputQuantizeConfig()

    return bn_layer_node
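
Examples #2 through #7 below are transform replacement hooks that prepare Conv + BatchNorm (+ ReLU) patterns for quantization: each bails out if any matched layer node already carries a user-supplied quantize config, then swaps the Conv activation for `NoOpActivation` and attaches an output quantize config to the BatchNorm node's metadata. A plausible sketch of the `_has_custom_quantize_config` helper they rely on (an assumption, not necessarily the actual implementation):

def _has_custom_quantize_config(*layer_nodes):
  # Return True if any matched layer node already has a user-provided
  # quantize config in its metadata; in that case the transform leaves
  # the pattern untouched.
  for layer_node in layer_nodes:
    if layer_node.metadata.get('quantize_config') is not None:
      return True
  return False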
Example #3
  def _replace(self, relu_layer_node, bn_layer_node, conv_layer_node):
    if _has_custom_quantize_config(
        relu_layer_node, bn_layer_node, conv_layer_node):
      return relu_layer_node

    conv_layer_node.layer['config']['activation'] = (
        keras.activations.serialize(quantize_aware_activation.NoOpActivation()))
    bn_layer_node.metadata['quantize_config'] = (
        default_8bit_quantize_configs.NoOpQuantizeConfig())

    return relu_layer_node
Example #4
  def replacement(self, match_layer):
    bn_layer_node, conv_layer_node = match_layer, match_layer.input_layers[0]

    if self._has_custom_quantize_provider(bn_layer_node, conv_layer_node):
      return match_layer

    conv_layer_node.layer['config']['activation'] = \
      keras.activations.serialize(quantize_aware_activation.NoOpActivation())
    bn_layer_node.metadata['quantize_provider'] = \
      tflite_quantize_providers.OutputQuantizeProvider()

    return match_layer
Example #5
  def replacement(self, match_layer):
    bn_layer_node, conv_layer_node = match_layer, match_layer.input_layers[0]

    if self._has_custom_quantize_config(bn_layer_node, conv_layer_node):
      return match_layer

    conv_layer_node.layer['config']['activation'] = \
      keras.activations.serialize(quantize_aware_activation.NoOpActivation())
    bn_layer_node.metadata['quantize_config'] = \
      default_8bit_quantize_configs.Default8BitOutputQuantizeConfig()

    return match_layer
Example #6
    def _replace(self, bn_layer_node, conv_layer_node):
        if _has_custom_quantize_config(bn_layer_node, conv_layer_node):
            return bn_layer_node

        conv_layer_node.layer['config']['activation'] = (
            keras.activations.serialize(
                quantize_aware_activation.NoOpActivation()))
        bn_layer_node.metadata['quantize_config'] = (
            configs.DefaultNBitOutputQuantizeConfig(
                num_bits_weight=self._num_bits_weight,
                num_bits_activation=self._num_bits_activation))

        return bn_layer_node
Example #7
    def replacement(self, match_layer):
        relu_layer_node = match_layer
        bn_layer_node = relu_layer_node.input_layers[0]
        conv_layer_node = bn_layer_node.input_layers[0]

        if self._has_custom_quantize_config(relu_layer_node, bn_layer_node,
                                            conv_layer_node):
            return match_layer

        conv_layer_node.layer['config']['activation'] = \
          keras.activations.serialize(quantize_aware_activation.NoOpActivation())
        bn_layer_node.metadata['quantize_config'] = \
          tflite_quantize_configs.NoOpQuantizeConfig()

        return match_layer
Example #8
  def testConstruction_SupportedAndUnsupportedActivations(self):
    layer = self.TestLayer()

    # Supported activations. No error thrown.
    QuantizeAwareActivation(activations.relu, self.quantizer, 0, layer)
    QuantizeAwareActivation(activations.softmax, self.quantizer, 0, layer)
    QuantizeAwareActivation(
        quantize_aware_activation.NoOpActivation(), self.quantizer, 0, layer)

    def custom_quantize(x):
      return x

    with self.assertRaises(ValueError) as cm:
      QuantizeAwareActivation(custom_quantize, self.quantizer, 0, layer)
    self.assertEqual(
        str(cm.exception), QuantizeAwareActivation._CUSTOM_ACTIVATION_ERR_MSG)
Example #9
class QuantizeAwareQuantizationTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(QuantizeAwareQuantizationTest, self).setUp()
    self.quantizer = MovingAverageQuantizer(
        num_bits=8, per_axis=False, symmetric=True, narrow_range=False)

  class TestLayer(keras.layers.Layer):

    def call(self, inputs, training=None):
      if training is None:
        training = K.learning_phase()

      self.activation.training = training
      # Going through `identity` to create a new tensor. TF throws an error
      # if input tensor is fetched during a run.
      return self.activation(tf.identity(inputs))

    def compute_output_shape(self, input_shape):
      return input_shape

  def testConstruction_SupportedAndUnsupportedActivations(self):
    layer = self.TestLayer()

    # Supported activations. No error thrown.
    QuantizeAwareActivation(activations.relu, self.quantizer, 0, layer)
    QuantizeAwareActivation(activations.softmax, self.quantizer, 0, layer)
    QuantizeAwareActivation(
        quantize_aware_activation.NoOpActivation(), self.quantizer, 0, layer)

    def custom_quantize(x):
      return x

    with self.assertRaises(ValueError) as cm:
      QuantizeAwareActivation(custom_quantize, self.quantizer, 0, layer)
    self.assertEqual(
        str(cm.exception), QuantizeAwareActivation._CUSTOM_ACTIVATION_ERR_MSG)

  def testAppliesQuantizationPostActivation(self):
    layer = self.TestLayer()
    layer.activation = QuantizeAwareActivation(
        activations.get('relu'), self.quantizer, 0, layer)

    model = keras.Sequential([layer])

    x = np.array([-6.0, -3.0, 0.0, 0.05, 0.1, 3.0, 6.0])
    # All negative values are removed due to ReLU. The other expected values
    # are the border values of float buckets when [-6, 6] range is quantized to
    # 256 buckets.
    # Derived using `tf.fake_quant_with_min_max_vars`
    expected_activation = np.array(
        [0.0, 0.0, 0.0, 0.04705906, 0.09411764, 3.011765,
         5.9764705]).reshape(7, 1)

    self.assertAllClose(expected_activation, model.predict(x))

  def testAppliesQuantizationPreActivation(self):
    layer = self.TestLayer()
    layer.activation = QuantizeAwareActivation(
        activations.get('softmax'), self.quantizer, 0, layer)

    model = keras.Sequential([layer])

    x = np.array([[1.0, 2.0]])
    # expected_activation is determined using the float buckets when [-6, 6] is
    # quantized. Derived using `tf.fake_quant_with_min_max_vars`. For softmax,
    # quantization is applied before the activation.
    #
    # FakeQuant([1.0, 2.0]) = [0.9882355, 1.9764705]
    # Softmax([0.9882355, 1.9764705]) = [0.27126083, 0.72873914]
    expected_activation = np.array([[0.27126083, 0.72873914]])

    self.assertAllClose(expected_activation, model.predict(x))

  def testDoesNotQuantizeNoOpActivation(self):
    layer = self.TestLayer()
    layer.activation = QuantizeAwareActivation(
        quantize_aware_activation.NoOpActivation(), self.quantizer, 0, layer)

    model = keras.Sequential([layer])

    x = np.array([[-2.0, -1.0, 1.0, 2.0]])
    self.assertAllClose(x, model.predict(x))

  @parameterized.parameters(
      (activations.get('relu'), {'activation': 'relu'}),
      (quantize_aware_activation.NoOpActivation(),
       {'activation': {'class_name': 'NoOpActivation', 'config': {}}})
  )
  def testSerializationReturnsWrappedActivation(
      self, activation, activation_config):
    quantize_activation = QuantizeAwareActivation(
        activation, self.quantizer, 0, self.TestLayer())
    serialized_quantize_activation = serialize_keras_object(quantize_activation)

    expected_config = {
        'class_name': 'QuantizeAwareActivation',
        'config': activation_config
    }
    self.assertEqual(expected_config, serialized_quantize_activation)

    deserialized_activation = deserialize_keras_object(
        serialized_quantize_activation,
        custom_objects={
            'QuantizeAwareActivation': QuantizeAwareActivation,
            'NoOpActivation': quantize_aware_activation.NoOpActivation
        })

    self.assertEqual(activation, deserialized_activation)
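
The expected values in `testAppliesQuantizationPostActivation` and `testAppliesQuantizationPreActivation` above come from fake-quantizing the [-6, 6] range into 256 levels, as the test comments note. A small standalone check (assuming TensorFlow 2.x eager execution) that reproduces those bucket values with `tf.quantization.fake_quant_with_min_max_vars`:

import tensorflow as tf

# Post-activation case: apply ReLU, then fake-quantize over [-6, 6] with 8 bits.
x = tf.constant([-6.0, -3.0, 0.0, 0.05, 0.1, 3.0, 6.0])
post = tf.quantization.fake_quant_with_min_max_vars(
    tf.nn.relu(x), min=-6.0, max=6.0, num_bits=8)
print(post.numpy())  # ~[0., 0., 0., 0.04705906, 0.09411764, 3.011765, 5.9764705]

# Pre-activation case: fake-quantize the inputs first, then apply softmax.
y = tf.constant([[1.0, 2.0]])
pre = tf.nn.softmax(
    tf.quantization.fake_quant_with_min_max_vars(y, min=-6.0, max=6.0, num_bits=8))
print(pre.numpy())  # ~[[0.27126083, 0.72873914]]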