Example #1
0
 def test_all_zero_input_works(self):
   # Encoding an all-zero input must not blow up. No min_max is passed
   # (min_max=None), so min and max are derived from the data; for all zeros
   # they are identical, which creates a potential division by zero.
   def zero_input():
     return tf.zeros([50])

   stage = quantization.PerChannelPRNGUniformQuantizationEncodingStage(bits=8)
   test_data = self.run_one_to_many_encode_decode(stage, zero_input)
   expected = np.zeros(50, dtype=np.float32)
   self.assertAllEqual(expected, test_data.decoded_x)
Example #2
0
 def test_input_types(self, x_dtype):
   # Tests that a random input of dtype `x_dtype` can be encoded and decoded.
   # x_dtype is presumably supplied by a parameterization decorator outside
   # this view -- TODO confirm.
   stage = quantization.PerChannelPRNGUniformQuantizationEncodingStage(bits=8)
   x = tf.random.normal([50], dtype=x_dtype)
   encode_params, decode_params = stage.get_params()
   encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                               decode_params)
   test_data = test_utils.TestData(x, encoded_x, decoded_x)
   test_data = self.evaluate_test_data(test_data)
   # Evaluation alone only proves that nothing raised; also check that the
   # encode/decode round trip preserves the input shape.
   self.assertAllEqual(test_data.x.shape, test_data.decoded_x.shape)
Example #3
0
 def test_quantization_bits_stochastic_rounding(self, bits):
   stage_cls = quantization.PerChannelPRNGUniformQuantizationEncodingStage
   stage = stage_cls(bits=bits)
   test_data = self.run_one_to_many_encode_decode(stage, self.default_input)

   # The quantized values must be integers stored in a float tensor.
   encoded_values = test_data.encoded_x[stage_cls.ENCODED_VALUES_KEY]
   self._assert_is_integer_float(encoded_values)

   # With stochastic rounding, the quantization error is bounded by the range
   # of the input values divided by the number of quantization buckets
   # (2**bits - 1). The numerator 2 is presumably the value range of
   # self.default_input -- TODO confirm.
   max_error = 2 / (2**bits - 1)
   self.assertAllClose(
       test_data.x, test_data.decoded_x, rtol=0.0, atol=max_error)
Example #4
0
  def test_quantization_empirically_unbiased(self):
    # Checks that the quantization "seems" to be unbiased. Over many
    # encode/decode executions, the mean of the per-run error norms must be
    # at least 10x the norm of the mean error vector: unbiased errors should
    # largely cancel when averaged.
    x = tf.constant(np.random.rand(50).astype(np.float32))
    stage = quantization.PerChannelPRNGUniformQuantizationEncodingStage(bits=2)
    encode_params, decode_params = stage.get_params()
    encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                                decode_params)
    test_data = test_utils.TestData(x, encoded_x, decoded_x)
    # Each evaluation presumably draws fresh PRNG randomness, giving an
    # independent sample of the quantization error -- TODO confirm.
    evaluated = [self.evaluate_test_data(test_data) for _ in range(200)]

    errors = [data.x - data.decoded_x for data in evaluated]
    mean_of_errors = np.mean([np.linalg.norm(err) for err in errors])
    error_of_mean = np.linalg.norm(np.mean(errors, axis=0))
    self.assertGreaterEqual(mean_of_errors, error_of_mean * 10)
Example #5
0
  def test_dynamic_input_shape(self):
    # Tests that encoding works when the input shape is not statically known,
    # and that decoding restores the input's (dynamic) shape.
    stage = quantization.PerChannelPRNGUniformQuantizationEncodingStage(bits=8)
    shape = [10, 5, 7]
    prob = [0.5, 0.8, 0.6]
    original_x = tf.compat.v1.random.poisson(1.5, shape, dtype=tf.float32)
    # Along each axis, keep every index whose uniform draw falls below that
    # axis's probability. The number of kept indices is data-dependent, so the
    # gathered tensor's static shape is unknown.
    rand = [tf.random.uniform([dim], dtype=tf.float32) for dim in shape]
    sample_indices = [
        tf.reshape(tf.where(r < p), [-1]) for r, p in zip(rand, prob)
    ]
    x = original_x
    for axis, indices in enumerate(sample_indices):
      x = tf.gather(x, indices, axis=axis)

    encode_params, decode_params = stage.get_params()
    encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                                decode_params)
    test_data = test_utils.TestData(x, encoded_x, decoded_x)
    test_data = self.evaluate_test_data(test_data)
    # Evaluation alone only proves that nothing raised; also check that the
    # round trip reproduces the dynamically determined shape.
    self.assertAllEqual(test_data.x.shape, test_data.decoded_x.shape)
Example #6
0
 def test_bits_out_of_range_raises(self, bits):
   # Constructing the stage with an unsupported bit width must raise a
   # ValueError whose message states the valid range. Uses assertRaisesRegex:
   # the deprecated assertRaisesRegexp alias was removed in Python 3.12.
   with self.assertRaisesRegex(ValueError, 'integer between 1 and 16'):
     quantization.PerChannelPRNGUniformQuantizationEncodingStage(bits=bits)
Example #7
0
 def default_encoding_stage(self):
   """Returns the stage under test, constructed with default arguments.

   See base class.
   """
   stage_cls = quantization.PerChannelPRNGUniformQuantizationEncodingStage
   return stage_cls()