def test_clipping_large_norm_identity(self):
    """Checks that a clip norm far above the input norm leaves the input intact."""

    def make_input():
        # The L2 norm of this vector is 2.0, well under the clip norm of 1000.0.
        return tf.constant([1.0, 1.0, 1.0, 1.0])

    large_norm_stage = clipping.ClipByNormEncodingStage(1000.0)
    result = self.run_one_to_many_encode_decode(large_norm_stage, make_input)
    self.common_asserts_for_test_data(result)

    # Nothing needs clipping here, so the round trip is expected to be an
    # identity on the input value.
    self.assertAllEqual(result.x, result.decoded_x)
def test_clipping_effective(self):
    """Checks that an input whose norm exceeds the clip norm is scaled down."""
    input_values = [1.0, 1.0, 1.0, 1.0]

    def make_input():
        return tf.constant([1.0, 1.0, 1.0, 1.0])

    unit_norm_stage = clipping.ClipByNormEncodingStage(1.0)
    result = self.run_one_to_many_encode_decode(unit_norm_stage, make_input)
    self.common_asserts_for_test_data(result)

    # The original input must come back untouched in the test data.
    self.assertAllEqual(input_values, result.x)

    # Four entries of 0.5 have L2 norm exactly 1, matching the clip norm.
    clipped_values = [0.5, 0.5, 0.5, 0.5]
    self.assertAllClose(clipped_values, result.decoded_x)
def test_input_types(self, x_dtype, clip_norm_dtype):
    """Exercises the stage with combinations of input and clip-norm dtypes.

    Args:
        x_dtype: dtype used to build the input tensor.
        clip_norm_dtype: dtype used to build the clip-norm tensor handed to
            the stage constructor.
    """
    clip_norm = tf.constant(1.0, dtype=clip_norm_dtype)
    stage = clipping.ClipByNormEncodingStage(clip_norm)
    x = tf.constant([1.0, 1.0, 1.0, 1.0], dtype=x_dtype)

    encode_params, decode_params = stage.get_params()
    encoded, decoded = self.encode_decode_x(
        stage, x, encode_params, decode_params)
    test_data = self.evaluate_test_data(
        test_utils.TestData(x, encoded, decoded))

    # The input must come back untouched in the evaluated test data.
    self.assertAllEqual([1.0, 1.0, 1.0, 1.0], test_data.x)
    # Four entries of 0.5 have L2 norm exactly 1, matching the clip norm.
    self.assertAllClose([0.5, 0.5, 0.5, 0.5], test_data.decoded_x)
def test_different_shapes(self, shape):
    """Checks that the decoded value has norm 1.0 for the given input shape.

    Args:
        shape: Shape of the random input tensor fed to the stage.
    """

    def make_input():
        # Uniform samples are shifted up by 1.0 before being encoded.
        return tf.random.uniform(shape) + 1.0

    unit_norm_stage = clipping.ClipByNormEncodingStage(1.0)
    result = self.run_one_to_many_encode_decode(unit_norm_stage, make_input)
    self.common_asserts_for_test_data(result)

    # With no axis/ord given, the norm is taken over the flattened array.
    decoded_norm = np.linalg.norm(result.decoded_x)
    self.assertAllClose(1.0, decoded_norm)
def default_encoding_stage(self):
    """Returns a `ClipByNormEncodingStage` with clip norm 1.0; see base class."""
    default_clip_norm = 1.0
    return clipping.ClipByNormEncodingStage(default_clip_norm)