def test_all_zero_input_works(self):
    # Encoding must survive an all-zero input. With min_max=None the stage
    # derives min and max from the data itself; here they are identical, which
    # is a potential division-by-zero hazard.
    zero_stage = quantization.PRNGUniformQuantizationEncodingStage(bits=8)
    result = self.run_one_to_many_encode_decode(
        zero_stage, lambda: tf.zeros([50]))
    expected = np.zeros(50).astype(np.float32)
    self.assertAllEqual(expected, result.decoded_x)
def test_input_types(self, x_dtype):
    # Tests that the stage handles inputs of the parameterized dtype.
    stage = quantization.PRNGUniformQuantizationEncodingStage(bits=8)
    x = tf.random.normal([50], dtype=x_dtype)
    encode_params, decode_params = stage.get_params()
    encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                                decode_params)
    test_data = test_utils.TestData(x, encoded_x, decoded_x)
    test_data = self.evaluate_test_data(test_data)
    # Beyond merely not raising, the round trip must reconstruct a tensor of
    # the same shape as the input.
    self.assertAllEqual(test_data.x.shape, test_data.decoded_x.shape)
def test_quantization_bits_stochastic_rounding(self, bits):
    stage_cls = quantization.PRNGUniformQuantizationEncodingStage
    test_data = self.run_one_to_many_encode_decode(
        stage_cls(bits=bits), self.default_input)
    encoded_values = test_data.encoded_x[stage_cls.ENCODED_VALUES_KEY]
    self._assert_is_integer_float(encoded_values)
    # With stochastic rounding, the quantization error is bounded by the range
    # of the input values divided by the number of quantization buckets
    # (2**bits - 1). The numerator 2 presumably reflects the range of
    # self.default_input -- TODO confirm against the base test class.
    num_buckets = 2**bits - 1
    self.assertAllClose(
        test_data.x, test_data.decoded_x, rtol=0.0, atol=2 / num_buckets)
def test_dynamic_input_shape(self):
    # Tests that encoding works when the input shape is not statically known.
    # The input is built by randomly sub-sampling every axis of a fixed-shape
    # tensor, so its final shape depends on random values.
    stage = quantization.PRNGUniformQuantizationEncodingStage(bits=8)
    shape = [10, 5, 7]
    prob = [0.5, 0.8, 0.6]  # Per-axis probability of keeping an index.
    original_x = tf.random.uniform(shape, dtype=tf.float32)
    rand = [tf.random.uniform([dim]) for dim in shape]
    sample_indices = [
        tf.reshape(tf.where(r < p), [-1]) for r, p in zip(rand, prob)
    ]
    x = original_x
    for axis, indices in enumerate(sample_indices):
        x = tf.gather(x, indices, axis=axis)
    encode_params, decode_params = stage.get_params()
    encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                                decode_params)
    test_data = test_utils.TestData(x, encoded_x, decoded_x)
    test_data = self.evaluate_test_data(test_data)
    # The round trip must preserve the (dynamic) shape.
    self.assertAllEqual(test_data.x.shape, test_data.decoded_x.shape)
    # The quantization error is bounded by the input range divided by the
    # number of buckets. tf.random.uniform values lie in [0, 1), so the range
    # is below 1.
    self.assertAllClose(
        test_data.x, test_data.decoded_x, rtol=0.0, atol=1 / (2**8 - 1))
def test_quantization_empirically_unbiased(self):
    # Empirical unbiasedness check: across many encode/decode runs, the mean
    # norm of the per-run errors should be much larger (more than 10x here)
    # than the norm of the error averaged over all runs.
    x = tf.constant(np.random.rand(50).astype(np.float32))
    stage = quantization.PRNGUniformQuantizationEncodingStage(bits=2)
    encode_params, decode_params = stage.get_params()
    encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                                decode_params)
    test_data = test_utils.TestData(x, encoded_x, decoded_x)
    runs = [self.evaluate_test_data(test_data) for _ in range(200)]
    errors = [run.x - run.decoded_x for run in runs]
    mean_of_errors = np.mean([np.linalg.norm(err) for err in errors])
    error_of_mean = np.linalg.norm(np.mean(errors, axis=0))
    self.assertGreater(mean_of_errors, error_of_mean * 10)
def default_encoding_stage(self):
    """See base class.

    Returns a PRNGUniformQuantizationEncodingStage built with its default
    constructor arguments.
    """
    stage_cls = quantization.PRNGUniformQuantizationEncodingStage
    return stage_cls()
def test_bits_out_of_range_raises(self, bits):
    # Constructing the stage with an unsupported `bits` value must raise a
    # ValueError whose message names the valid range.
    # Uses assertRaisesRegex: the old assertRaisesRegexp alias was deprecated
    # and removed from unittest in Python 3.12.
    with self.assertRaisesRegex(ValueError, 'integer between 1 and 16'):
        quantization.PRNGUniformQuantizationEncodingStage(bits=bits)