def test_clipping_large_norm_identity(self):
     """Verifies that a large clip norm makes the stage act as an identity.

     The stage is built with a clip norm of 1000.0 and fed a vector of four
     ones; the decoded value must then equal the input exactly.
     """
     def ones_fn():
         return tf.constant([1.0, 1.0, 1.0, 1.0])

     large_norm_stage = clipping.ClipByNormEncodingStage(1000.0)
     result = self.run_one_to_many_encode_decode(large_norm_stage, ones_fn)
     self.common_asserts_for_test_data(result)
     # The input norm is far below the clip norm, so the encode/decode round
     # trip is expected to return the input unchanged.
     self.assertAllEqual(result.x, result.decoded_x)
 def test_clipping_effective(self):
     """Verifies that clipping takes effect when the input norm is too large.

     Four ones (norm 2) are run through a stage built with clip norm 1.0; the
     recorded input must be unchanged and each decoded element must be 0.5.
     """
     ones = [1.0, 1.0, 1.0, 1.0]

     def ones_fn():
         return tf.constant(ones)

     unit_norm_stage = clipping.ClipByNormEncodingStage(1.0)
     result = self.run_one_to_many_encode_decode(unit_norm_stage, ones_fn)
     self.common_asserts_for_test_data(result)
     # The recorded input is still the original vector of ones.
     self.assertAllEqual(ones, result.x)
     # The decoded values should have norm 1, i.e. 0.5 per element.
     self.assertAllClose([0.5] * 4, result.decoded_x)
    def test_input_types(self, x_dtype, clip_norm_dtype):
        """Verifies clipping for one pairing of input and clip-norm dtypes.

        NOTE(review): the extra arguments are presumably supplied by a
        parameterized decorator that is not visible here -- confirm.

        Args:
          x_dtype: dtype used to build the input tensor of four ones.
          clip_norm_dtype: dtype used to build the scalar clip-norm tensor.
        """
        ones = [1.0, 1.0, 1.0, 1.0]

        # The clip norm is passed as a tensor rather than a Python float so
        # that its dtype can differ from the dtype of the input.
        clip_norm = tf.constant(1.0, clip_norm_dtype)
        stage = clipping.ClipByNormEncodingStage(clip_norm)
        x = tf.constant(ones, dtype=x_dtype)

        enc_params, dec_params = stage.get_params()
        encoded, decoded = self.encode_decode_x(
            stage, x, enc_params, dec_params)
        evaluated = self.evaluate_test_data(
            test_utils.TestData(x, encoded, decoded))

        # The recorded input is still the original vector of ones.
        self.assertAllEqual(ones, evaluated.x)
        # The decoded values should have norm 1, i.e. 0.5 per element.
        self.assertAllClose([0.5] * 4, evaluated.decoded_x)
 def test_different_shapes(self, shape):
     """Verifies the decoded value has norm 1 for an input of the given shape.

     NOTE(review): `shape` is presumably supplied by a parameterized decorator
     that is not visible here -- confirm.

     Args:
       shape: shape passed to `tf.random.uniform` to build the input.
     """
     def shifted_uniform_fn():
         # NOTE(review): the + 1.0 shift presumably keeps the input norm above
         # the clip norm so that clipping always applies -- confirm.
         return tf.random.uniform(shape) + 1.0

     unit_norm_stage = clipping.ClipByNormEncodingStage(1.0)
     result = self.run_one_to_many_encode_decode(
         unit_norm_stage, shifted_uniform_fn)
     self.common_asserts_for_test_data(result)
     # Whatever the shape, the decoded value as a whole should have norm 1.
     decoded_norm = np.linalg.norm(result.decoded_x)
     self.assertAllClose(1.0, decoded_norm)
 def default_encoding_stage(self):
     """See base class.

     Returns:
       A `clipping.ClipByNormEncodingStage` constructed with the value 1.0.
     """
     stage = clipping.ClipByNormEncodingStage(1.0)
     return stage