def test_assert_all_below_exception_raised(self, dtype):
    """Verifies that assert_all_below raises for inputs violating a 1.0 bound.

    A random vector is squared and divided by its maximum along the last
    axis, then shifted up by an epsilon chosen for `dtype`. That shifted
    tensor is expected to fail against the bound 1.0 with both open and
    closed bounds. A tensor of ones is expected to fail only the open-bound
    check exercised here.

    Args:
      dtype: TensorFlow dtype used for the test tensor and for selecting the
        epsilon (presumably supplied by a parameterized decorator that is not
        visible in this chunk -- confirm).
    """
    base = tf.convert_to_tensor(value=_pick_random_vector(), dtype=dtype)
    squared = base * base
    # Scale so the largest entry along the last axis becomes 1
    # (assumes the maximum is non-zero).
    normalized = squared / tf.reduce_max(
        input_tensor=squared, axis=-1, keepdims=True)
    eps = asserts.select_eps_for_addition(dtype)
    outside_vector = normalized + eps
    ones_vector = tf.ones_like(normalized)

    # (subtest name, tensor under test, open_bound) -- every case must raise.
    failing_cases = (
        ("outside_and_open_bounds", outside_vector, True),
        ("outside_and_close_bounds", outside_vector, False),
        ("exact_and_open_bounds", ones_vector, True),
    )
    for case_name, values, is_open in failing_cases:
        with self.subTest(name=case_name), self.assertRaises(
                tf.errors.InvalidArgumentError):
            self.evaluate(
                asserts.assert_all_below(values, 1.0, open_bound=is_open))
def test_assert_all_below_passthrough(self):
    """Verifies that assert_all_below hands back the very object it was given.

    The call uses a bound of 0.0 and only checks object identity between the
    input and the returned value.
    NOTE(review): presumably this relies on a debug/assert flag being disabled
    by default so that no check is performed -- confirm against `asserts`.
    """
    original = _pick_random_vector()
    returned = asserts.assert_all_below(original, 0.0)
    # Identity, not equality: the same Python object must come back.
    self.assertIs(original, returned)