Example #1
0
    def test_assert_all_below_exception_raised(self, dtype):
        """Verifies that assert_all_below rejects inputs violating the bound.

        A random vector is squared and rescaled so that its largest entry along
        the last axis equals 1.0. Two invalid inputs are derived from it:

        * the rescaled vector shifted up by a dtype-dependent epsilon, whose
          largest entry therefore lies strictly above 1.0;
        * a vector of ones, which sits exactly on the bound.

        The first must fail for both open and closed bounds. The second must
        fail for an open bound, i.e. an open bound rejects values equal to it.

        Args:
          dtype: The tensor dtype under test (presumably supplied by a
            parameterized decorator outside this view).
        """
        raw = tf.convert_to_tensor(value=_pick_random_vector(), dtype=dtype)
        squared = raw * raw
        # Rescale so the maximum along the last axis is exactly 1.0; assumes
        # the random vector is not all zeros.
        normalized = squared / tf.reduce_max(
            input_tensor=squared, axis=-1, keepdims=True)
        # NOTE(review): the epsilon is presumably the smallest step for which
        # 1.0 + eps > 1.0 still holds in `dtype`; confirm in asserts.
        shifted = normalized + asserts.select_eps_for_addition(dtype)
        ones = tf.ones_like(normalized)

        # (subtest name, input tensor, open_bound) -- every case must raise.
        failing_cases = (
            ("outside_and_open_bounds", shifted, True),
            ("outside_and_close_bounds", shifted, False),
            ("exact_and_open_bounds", ones, True),
        )
        for case_name, candidate, is_open in failing_cases:
            with self.subTest(name=case_name):
                with self.assertRaises(tf.errors.InvalidArgumentError):
                    self.evaluate(
                        asserts.assert_all_below(candidate,
                                                 1.0,
                                                 open_bound=is_open))
Example #2
0
    def test_assert_all_below_passthrough(self):
        """Verifies assert_all_below hands back its input object unchanged.

        The test expects no check to be performed, so the very same object must
        be returned even for a bound of 0.0. Per the original description, this
        holds when the asserts flag is False. That flag is presumably set by a
        decorator outside this view; confirm.
        """
        random_vector = _pick_random_vector()
        # Identity (not mere equality) proves the input passed straight through.
        self.assertIs(random_vector,
                      asserts.assert_all_below(random_vector, 0.0))