def test_clipping_large_min_max_identity(self):
    """Bounds much wider than the input make the stage a round-trip identity."""
    wide_stage = clipping.ClipByValueEncodingStage(-1000.0, 1000.0)
    data = self.run_one_to_many_encode_decode(wide_stage, self.default_input)
    self.common_asserts_for_test_data(data)
    # Assumes every value of self.default_input lies within [-1000, 1000]
    # (not visible here). In that case clipping changes nothing, so the
    # decoded tensor must match the input exactly.
    self.assertAllEqual(data.x, data.decoded_x)
 def test_different_shapes(self, shape):
     """Decoded values stay within [-1, 1] for a random input of `shape`.

     `shape` is presumably supplied by a parameterization decorator that is
     not visible in this chunk.
     """
     unit_stage = clipping.ClipByValueEncodingStage(-1.0, 1.0)

     def random_input():
         return tf.random.normal(shape)

     data = self.run_one_to_many_encode_decode(unit_stage, random_input)
     self.common_asserts_for_test_data(data)
     decoded = data.decoded_x
     # Upper bound: max(decoded) <= 1.0; lower bound: min(decoded) >= -1.0.
     self.assertGreaterEqual(1.0, np.amax(decoded))
     self.assertLessEqual(-1.0, np.amin(decoded))
 def test_clipping_effective(self):
     """Values outside [-1, 1] are clamped to the nearest bound on decode."""
     raw_values = [-2.0, -1.0, 0.0, 1.0, 2.0]
     clipped_values = [-1.0, -1.0, 0.0, 1.0, 1.0]
     unit_stage = clipping.ClipByValueEncodingStage(-1.0, 1.0)
     data = self.run_one_to_many_encode_decode(
         unit_stage, lambda: tf.constant(raw_values))
     self.common_asserts_for_test_data(data)
     # The recorded input is the unclipped original...
     self.assertAllEqual(raw_values, data.x)
     # ...while the decoded output has -2.0 / 2.0 clamped to -1.0 / 1.0.
     self.assertAllClose(clipped_values, data.decoded_x)
    def test_input_types(self, x_dtype, clip_value_min_dtype,
                         clip_value_max_dtype):
        """Clipping works across combinations of input and bound dtypes.

        The dtype arguments are presumably supplied by a parameterization
        decorator that is not visible in this chunk.
        """
        raw_values = [-2.0, -1.0, 0.0, 1.0, 2.0]
        clipped_values = [-1.0, -1.0, 0.0, 1.0, 1.0]

        # Each clip bound is a tensor of its own dtype, independent of x.
        lower_bound = tf.constant(-1.0, clip_value_min_dtype)
        upper_bound = tf.constant(1.0, clip_value_max_dtype)
        stage = clipping.ClipByValueEncodingStage(lower_bound, upper_bound)

        x = tf.constant(raw_values, dtype=x_dtype)
        enc_params, dec_params = stage.get_params()
        encoded_x, decoded_x = self.encode_decode_x(
            stage, x, enc_params, dec_params)
        data = self.evaluate_test_data(
            test_utils.TestData(x, encoded_x, decoded_x))

        self.common_asserts_for_test_data(data)
        self.assertAllEqual(raw_values, data.x)
        self.assertAllClose(clipped_values, data.decoded_x)
 def default_encoding_stage(self):
     """See base class.

     Returns a ClipByValueEncodingStage with bounds -1.0 and 1.0.
     """
     stage = clipping.ClipByValueEncodingStage(-1.0, 1.0)
     return stage