def testSerialization(self):
    """Checks the serialized form of a 4-bit config and its round trip.

    The config is passed through serialize_keras_object, compared against a
    literal dict, then passed through deserialize_keras_object and compared
    with the original.
    """
    config = n_bit_registry.DefaultNBitQuantizeConfig(
        ['kernel'], ['activation'], False,
        num_bits_weight=4, num_bits_activation=4)

    serialized = serialize_keras_object(config)
    self.assertEqual(
        {
            'class_name': 'DefaultNBitQuantizeConfig',
            'config': {
                'weight_attrs': ['kernel'],
                'activation_attrs': ['activation'],
                'quantize_output': False,
                'num_bits_weight': 4,
                'num_bits_activation': 4
            }
        },
        serialized)

    # The registry's type map is passed as custom_objects so the class name
    # in the serialized dict can be resolved back to a class.
    restored = deserialize_keras_object(
        serialized,
        module_objects=globals(),
        custom_objects=n_bit_registry._types_dict())
    self.assertEqual(config, restored)
  def testGetsResultQuantizers_EmptyWhenFalse(self):
    """No output quantizers are returned when the output flag is False."""
    dense = self._simple_dense_layer()
    # No weight/activation attrs; the third positional argument is the flag
    # that serializes as 'quantize_output'.
    config = n_bit_registry.DefaultNBitQuantizeConfig(
        [], [], False, num_bits_weight=4, num_bits_activation=4)

    self.assertEqual([], config.get_output_quantizers(dense))
  def testGetsResultQuantizers_ReturnsQuantizer(self):
    """Exactly one output quantizer is returned when the output flag is True."""
    dense = self._simple_dense_layer()
    config = n_bit_registry.DefaultNBitQuantizeConfig(
        [], [], True, num_bits_weight=4, num_bits_activation=4)

    quantizers = config.get_output_quantizers(dense)

    self.assertLen(quantizers, 1)
    # Validated with the same helper that checks activation quantizers.
    self._assert_activation_quantizers(quantizers)
  def testSetsQuantizeWeights_ErrorOnWrongShapeOfWeight(self):
    """set_quantize_weights raises ValueError for a mis-shaped kernel."""
    dense = self._simple_dense_layer()
    # A [1, 2] tensor of ones. The test relies on this differing from the
    # layer's kernel shape.
    mismatched_kernel = K.variable(np.ones([1, 2]))

    config = n_bit_registry.DefaultNBitQuantizeConfig(
        ['kernel'], ['activation'], False,
        num_bits_weight=4, num_bits_activation=4)

    with self.assertRaises(ValueError):
      config.set_quantize_weights(dense, [mismatched_kernel])
  def testSetsQuantizeActivations(self):
    """set_quantize_activations makes the layer use the given activation."""
    layer = self._simple_dense_layer()
    quantize_activation = keras.activations.relu

    quantize_config = n_bit_registry.DefaultNBitQuantizeConfig(
        ['kernel'], ['activation'], False,
        num_bits_weight=4, num_bits_activation=4)
    quantize_config.set_quantize_activations(layer, [quantize_activation])

    # Expected value first, matching the assertEqual(expected, actual) order
    # used by the other tests here, so failure messages read correctly.
    self.assertEqual(quantize_activation, layer.activation)
  def testSetsQuantizeWeights(self):
    """set_quantize_weights makes the layer kernel equal the given tensor."""
    dense = self._simple_dense_layer()
    kernel_shape = dense.kernel.shape.as_list()
    # Same shape as the existing kernel, unlike the wrong-shape test above,
    # so this replacement is expected to be accepted.
    replacement_kernel = K.variable(np.ones(kernel_shape))

    config = n_bit_registry.DefaultNBitQuantizeConfig(
        ['kernel'], ['activation'], False,
        num_bits_weight=4, num_bits_activation=4)
    config.set_quantize_weights(dense, [replacement_kernel])

    self._assert_kernel_equality(dense.kernel, replacement_kernel)
  def testGetsQuantizeActivationsAndQuantizers(self):
    """The configured activation attr is returned with a valid quantizer."""
    dense = self._simple_dense_layer()

    config = n_bit_registry.DefaultNBitQuantizeConfig(
        ['kernel'], ['activation'], False,
        num_bits_weight=4, num_bits_activation=4)
    pairs = config.get_activations_and_quantizers(dense)
    # NOTE(review): _convert_list is defined elsewhere in this class and
    # presumably splits the pairs into (activations, quantizers) lists.
    activations, activation_quantizers = self._convert_list(pairs)

    self._assert_activation_quantizers(activation_quantizers)
    self.assertEqual([dense.activation], activations)
  def testSetsQuantizeActivations_ErrorOnWrongNumberOfActivations(self):
    """set_quantize_activations raises ValueError on a count mismatch."""
    layer = self._simple_dense_layer()
    relu = keras.activations.relu

    quantize_config = n_bit_registry.DefaultNBitQuantizeConfig(
        ['kernel'], ['activation'], False,
        num_bits_weight=4, num_bits_activation=4)

    # Exactly one activation attr is configured, so both an empty list and a
    # two-element list must be rejected. They are checked in that order.
    for bad_activations in ([], [relu, relu]):
      with self.assertRaises(ValueError):
        quantize_config.set_quantize_activations(layer, bad_activations)