def test_larger_num_iters_improves_accuracy(self):
  """Checks that more Kashin iterations give a strictly smaller error.

  With `last_iter_clip=True` the representation is potentially lossy. `delta`
  is set large so that the effect of `num_iters` on accuracy is measurable.
  """
  stage_cls = kashin.KashinHadamardEncodingStage
  x = np.random.randn(3, 12).astype(np.float32)
  # The same fixed seed is used for every num_iters value so that the runs are
  # comparable and only num_iters varies between them.
  fixed_seed = tf.constant([1, 2], tf.int64)

  def reconstruction_error(num_iters):
    """Encodes/decodes `x` with `num_iters` iterations; returns the L2 error."""
    stage = stage_cls(
        num_iters=num_iters, eta=0.9, delta=100.0, last_iter_clip=True)
    encode_params, decode_params = stage.get_params()
    encode_params[stage_cls.SEED_PARAMS_KEY] = fixed_seed
    decode_params[stage_cls.SEED_PARAMS_KEY] = fixed_seed
    encoded_x, decoded_x = self.encode_decode_x(
        stage, x, encode_params, decode_params)
    evaluated = self.evaluate_test_data(
        test_utils.TestData(x, encoded_x, decoded_x))
    return np.linalg.norm(evaluated.x - evaluated.decoded_x)

  errors = [reconstruction_error(n) for n in (2, 3, 4, 5)]
  # Fewer iterations must incur a strictly larger error than more iterations.
  for fewer_iters_error, more_iters_error in zip(errors, errors[1:]):
    self.assertGreater(fewer_iters_error, more_iters_error)
def test_all_zero_input_works(self, last_iter_clip):
  """Checks that an all-zero input neither blows up nor decodes to non-zeros."""

  def zero_input():
    return tf.zeros([3, 12])

  stage = kashin.KashinHadamardEncodingStage(
      num_iters=3, eta=0.9, delta=1.0, last_iter_clip=last_iter_clip)
  data = self.run_one_to_many_encode_decode(stage, zero_input)
  self.common_asserts_for_test_data(data)
  expected = np.zeros((3, 12), dtype=np.float32)
  self.assertAllEqual(expected, data.decoded_x)
def test_eta_delta_take_tf_values(self):
  """Checks that `eta` and `delta` may be passed as TensorFlow constants.

  The manually built `TestData` is evaluated before the shared assertions run,
  in the same way as the other tests here that construct `TestData` by hand.
  """
  x = self.default_input()
  stage = kashin.KashinHadamardEncodingStage(
      eta=tf.constant(0.9), delta=tf.constant(1.0))
  encode_params, decode_params = stage.get_params()
  encoded_x, decoded_x = self.encode_decode_x(
      stage, x, encode_params, decode_params)
  test_data = test_utils.TestData(x, encoded_x, decoded_x)
  # Turn the tensors into concrete values before asserting on them; the
  # shared assertion helpers are otherwise handed unevaluated tensors.
  test_data = self.evaluate_test_data(test_data)
  self.generic_asserts(test_data, stage)
  self.common_asserts_for_test_data(test_data)
def test_last_iter_clip_false_is_lossless(self):
  """With `last_iter_clip=False`, decoding must reproduce the input.

  `delta` is deliberately set large so that there is something to clip in the
  last iteration; otherwise the test would not be meaningful.
  """

  def random_input():
    return tf.random.normal([3, 12])

  lossless_stage = kashin.KashinHadamardEncodingStage(
      num_iters=2, eta=0.9, delta=100.0, last_iter_clip=False)
  result = self.run_one_to_many_encode_decode(lossless_stage, random_input)
  self.assertAllClose(result.x, result.decoded_x)
def test_with_multiple_input_shapes_pad_0_8(self, input_dims,
                                            expected_output_dims):
  """Checks the encoded-values shape for a given input shape at threshold 0.8.

  Args:
    input_dims: Shape of the random input to encode.
    expected_output_dims: Expected shape of the encoded values.
  """
  values_key = kashin.KashinHadamardEncodingStage.ENCODED_VALUES_KEY
  stage = kashin.KashinHadamardEncodingStage(pad_extra_level_threshold=0.8)

  def random_input():
    return tf.random.normal(input_dims)

  data = self.run_one_to_many_encode_decode(stage, random_input)
  self.common_asserts_for_test_data(data)
  # The encoded values must have exactly the expected shape.
  actual_shape = data.encoded_x[values_key].shape
  self.assertEqual(expected_output_dims, actual_shape)
def test_input_types(self, x_dtype, eta_dtype, delta_dtype):
  """Checks that the stage accepts the given dtypes for x, `eta` and `delta`.

  Args:
    x_dtype: dtype of the random input tensor.
    eta_dtype: dtype of the `eta` constant.
    delta_dtype: dtype of the `delta` constant.
  """
  eta = tf.constant(0.9, eta_dtype)
  delta = tf.constant(1.0, delta_dtype)
  stage = kashin.KashinHadamardEncodingStage(eta=eta, delta=delta)
  x = tf.random.normal([3, 12], dtype=x_dtype)
  encode_params, decode_params = stage.get_params()
  encoded_x, decoded_x = self.encode_decode_x(
      stage, x, encode_params, decode_params)
  evaluated = self.evaluate_test_data(
      test_utils.TestData(x, encoded_x, decoded_x))
  # The round trip must preserve the input shape.
  self.assertAllEqual(evaluated.x.shape, evaluated.decoded_x.shape)
def default_encoding_stage(self):
  """See base class.

  Returns a `KashinHadamardEncodingStage` constructed with default arguments.
  """
  stage = kashin.KashinHadamardEncodingStage()
  return stage
def test_pad_extra_level_threshold_tensor_raises(self):
  """A tensor-valued `pad_extra_level_threshold` must raise ValueError."""
  # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
  # which was removed from unittest in Python 3.12.
  with self.assertRaisesRegex(ValueError, 'pad_extra_level_threshold'):
    kashin.KashinHadamardEncodingStage(
        pad_extra_level_threshold=tf.constant(0.8, dtype=tf.float32))
def test_last_iter_clip_not_bool_raises(self, last_iter_clip):
  """A non-bool `last_iter_clip` must raise ValueError.

  Args:
    last_iter_clip: A non-bool value passed as `last_iter_clip`.
  """
  # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
  # which was removed from unittest in Python 3.12.
  with self.assertRaisesRegex(ValueError, 'last_iter_clip must be a bool'):
    kashin.KashinHadamardEncodingStage(last_iter_clip=last_iter_clip)
def test_last_iter_clip_tensor_raises(self):
  """A tensor-valued `last_iter_clip` must raise ValueError."""
  # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
  # which was removed from unittest in Python 3.12.
  with self.assertRaisesRegex(ValueError, 'last_iter_clip'):
    kashin.KashinHadamardEncodingStage(
        last_iter_clip=tf.constant(True, dtype=tf.bool))
def test_num_iters_tensor_raises(self):
  """A tensor-valued `num_iters` must raise ValueError."""
  # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
  # which was removed from unittest in Python 3.12.
  with self.assertRaisesRegex(ValueError, 'num_iters'):
    kashin.KashinHadamardEncodingStage(
        num_iters=tf.constant(2, dtype=tf.int32))
def test_num_iters_small_raises(self, num_iters):
  """A too-small `num_iters` must raise ValueError mentioning 'positive'.

  Args:
    num_iters: The invalid `num_iters` value under test.
  """
  # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
  # which was removed from unittest in Python 3.12.
  with self.assertRaisesRegex(ValueError, 'positive'):
    kashin.KashinHadamardEncodingStage(num_iters=num_iters)
def test_delta_small_raises(self, delta):
  """A too-small `delta` must raise ValueError mentioning 'greater than 0'.

  Args:
    delta: The invalid `delta` value under test.
  """
  # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
  # which was removed from unittest in Python 3.12.
  with self.assertRaisesRegex(ValueError, 'greater than 0'):
    kashin.KashinHadamardEncodingStage(delta=delta)
def test_eta_out_of_bounds_raises(self, eta):
  """An out-of-range `eta` must raise ValueError mentioning 'between 0 and 1'.

  Args:
    eta: The invalid `eta` value under test.
  """
  # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
  # which was removed from unittest in Python 3.12.
  with self.assertRaisesRegex(ValueError, 'between 0 and 1'):
    kashin.KashinHadamardEncodingStage(eta=eta)