def test_validate_args_vanilla_function(self):
    """A function without annotations accepts any positional values."""
    def example_function(a, b, c="a", d=6):
        return 1.0

    Enforcer(example_function).verify_args([1, 2, 3, "a"], {})
def test_validate_args_list(self):
    """A ``List(Str)`` hint on ``*args`` alongside an ``Int`` parameter."""
    def example_function(a: Int, *args: List(Str)):
        pass

    checker = Enforcer(example_function)
    positional = [1, ["a", "b", "c"]]
    checker.verify_args(positional, {})
def test_validate_args_with_type_hints(self):
    """Builtin-type hints accept matching values, positionally or by keyword."""
    def example_function(a: int, b, c: str = "aa", d=5):
        pass

    checker = Enforcer(example_function)
    as_positional = [1, 2, "string", "string2"]
    # Keyword order is kept as written; it can affect which error surfaces first.
    as_keywords = {"d": 10, "c": "bb", "b": "cc", "a": 2}
    checker.verify_args(as_positional, {})
    checker.verify_args([], as_keywords)
def test_widening_coercion(self):
    """``int`` values are accepted where ``float`` is hinted (widening)."""
    def example_function(a: float, b: float) -> float:
        pass

    checker = Enforcer(example_function)
    # Neither call is expected to raise.
    checker.verify_args([1, 1], {})
    checker.verify_result(1)
def test_validate_trait_types(self):
    """Trait-type hints accept matching arguments and return values."""
    def example_function(a: Str, b: Int) -> Float:
        pass

    checker = Enforcer(example_function)
    checker.verify_args(["aa", 0], {})
    checker.verify_args([], {"b": 0, "a": "string"})
    # Both a float and an int are accepted for the Float return hint.
    for returned in (0.1, 1):
        checker.verify_result(returned)
class EnforceTypeHints:
    """Callable wrapper that checks a function's arguments and return value.

    The ``Enforcer`` is built lazily on the first direct call. When the wrapper
    is assigned as a class attribute, it is built eagerly in ``__set_name__``.
    In that case the wrapper replaces itself on the owning class with the
    checked function, re-wrapped as ``staticmethod``/``classmethod`` when the
    original was one.
    """

    def __init__(self, func, require_args, require_return):
        self.func = func
        self.require_args = require_args
        self.require_return = require_return
        # Created by ``decorate``; ``None`` means "not decorated yet".
        self.enforcer = None

    def __call__(self, *args, **kwargs):
        if self.enforcer is None:
            self.decorate()
        return self.decorated_func(*args, **kwargs)

    def __set_name__(self, owner, name):
        # Runs at class-creation time, which is what lets methods be told
        # apart from free functions.
        rewrap = None
        skip_first_param = True
        if isinstance(self.func, staticmethod):
            # A static method has no implicit first parameter to skip.
            rewrap = staticmethod
            skip_first_param = False
        elif isinstance(self.func, classmethod):
            rewrap = classmethod

        if rewrap is not None:
            # Unwrap the descriptor so the raw function gets checked.
            self.func = self.func.__func__
        self.decorate(ignore_self=skip_first_param)
        if rewrap is not None:
            self.decorated_func = rewrap(self.decorated_func)
        setattr(owner, name, self.decorated_func)

    def decorate(self, ignore_self=False):
        """Build the enforcer and store the checking wrapper in ``decorated_func``."""
        self.enforcer = Enforcer(
            self.func,
            require_args=self.require_args,
            require_return=self.require_return,
            ignore_self=ignore_self,
        )

        def checked(*args, **kwargs):
            # ``self.enforcer`` and ``self.func`` are looked up per call on purpose.
            self.enforcer.verify_args(args, kwargs)
            outcome = self.func(*args, **kwargs)
            self.enforcer.verify_result(outcome)
            return outcome

        self.decorated_func = wraps(self.func)(checked)
def test_validate_args_with_invalid_defaults(self):
    """A default that violates its own builtin-type hint is reported."""
    def example_function(a: int, b, c: int = "a", d=6) -> float:
        return 1.0

    checker = Enforcer(example_function)
    expected_message = (
        "The 'c' parameter of 'example_function' must be <class 'int'>, "
        "but a value of 'a' <class 'str'> was specified."
    )
    with self.assertRaises(ParameterTypeError) as ctx:
        checker.verify_args([], {})
    self.assertEqual(expected_message, str(ctx.exception))
def test_validete_with_invalid_defaults(self):
    """A default that violates its own trait-type hint is reported."""
    # NOTE(review): "validete" is a typo; kept so the test ID stays stable.
    def example_function(a: Str = "a", b: Int = "b"):
        pass

    checker = Enforcer(example_function)
    expected_message = (
        "The 'b' parameter of 'example_function' must be "
        "<class 'traits.trait_types.Int'>, but a value of 'b' "
        "<class 'str'> was specified."
    )
    # Pass nothing, so only the defaults are in play.
    with self.assertRaises(ParameterTypeError) as ctx:
        checker.verify_args([], {})
    self.assertEqual(expected_message, str(ctx.exception))
def test_validate_tuple(self):
    """``Tuple`` hints are checked element by element, in order."""
    def example_function(a: Tuple(Str, Int)) -> Tuple(Int, Str):
        pass

    checker = Enforcer(example_function)
    checker.verify_args([("a", 2)], {})
    checker.verify_result((2, "b"))

    # Same element types, swapped positions: both must be rejected.
    with self.assertRaises(ParameterTypeError):
        checker.verify_args([(2, "a")], {})
    with self.assertRaises(ReturnTypeError):
        checker.verify_result(("b", 2))
def test_validate_trait_types_invalid_args(self):
    """Trait-type hints reject a mismatched argument and return value."""
    # Previously named ``test_validate_args_invalid_args``, the same name as the
    # builtin-hint variant of this test. Inside one class body the later
    # definition rebinds the name, so this test was silently never run.
    def example_function(a: Str, b: Int) -> Float:
        pass

    enforcer = Enforcer(example_function)

    with self.assertRaises(ParameterTypeError) as err:
        enforcer.verify_args([90, 10], {})
    self.assertEqual(
        "The 'a' parameter of 'example_function' must be "
        "<class 'traits.trait_types.Str'>, but a value of 90 "
        "<class 'int'> was specified.",
        str(err.exception))

    with self.assertRaises(ReturnTypeError) as err:
        enforcer.verify_result("bad")
    self.assertEqual(
        "The return type of 'example_function' must be "
        "<class 'traits.trait_types.Float'>, but a value of 'bad' "
        "<class 'str'> was returned.",
        str(err.exception))
def test_narrowing_coercion(self):
    """``float`` values are rejected where ``int`` is hinted (narrowing)."""
    def example_function(a: int, b: int) -> int:
        pass

    checker = Enforcer(example_function)

    with self.assertRaises(ParameterTypeError) as arg_ctx:
        checker.verify_args([1.0, 2.0], {})
    self.assertEqual(
        "The 'a' parameter of 'example_function' must be <class 'int'>, "
        "but a value of 1.0 <class 'float'> was specified.",
        str(arg_ctx.exception),
    )

    with self.assertRaises(ReturnTypeError) as ret_ctx:
        checker.verify_result(1.0)
    self.assertEqual(
        "The return type of 'example_function' must be <class 'int'>, "
        "but a value of 1.0 <class 'float'> was returned.",
        str(ret_ctx.exception),
    )
def test_validate_args_invalid_args(self):
    """Builtin-type hints reject mismatches, positionally or by keyword."""
    def example_function(a: int, b, c: str = "aa", d=5):
        pass

    checker = Enforcer(example_function)

    with self.assertRaises(ParameterTypeError) as positional_ctx:
        checker.verify_args([1, "ok", 0, "ok"], {})
    self.assertEqual(
        "The 'c' parameter of 'example_function' must be <class 'str'>, "
        "but a value of 0 <class 'int'> was specified.",
        str(positional_ctx.exception),
    )

    bad_keywords = {"d": 10, "c": "ok", "b": "cc", "a": "y"}
    with self.assertRaises(ParameterTypeError) as keyword_ctx:
        checker.verify_args([], bad_keywords)
    self.assertEqual(
        "The 'a' parameter of 'example_function' must be <class 'int'>, "
        "but a value of 'y' <class 'str'> was specified.",
        str(keyword_ctx.exception),
    )
def test_validate_args_multiple_invalid_args_order(self):
    """When several arguments are invalid, the reported one is deterministic."""
    def example_function(a: int, b, c: str = "aa", d=5):
        pass

    checker = Enforcer(example_function)
    a_is_bad = (
        "The 'a' parameter of 'example_function' must be <class 'int'>, "
        "but a value of 'y' <class 'str'> was specified."
    )

    # Positional only: the first invalid argument is the one reported.
    with self.assertRaises(ParameterTypeError) as ctx:
        checker.verify_args(["bad", "ok", 0, "ok"], {})
    self.assertEqual(
        "The 'a' parameter of 'example_function' must be <class 'int'>, "
        "but a value of 'bad' <class 'str'> was specified.",
        str(ctx.exception),
    )

    # Keywords only: the first parameter in the signature is reported,
    # not the first keyword in the dict.
    with self.assertRaises(ParameterTypeError) as ctx:
        checker.verify_args([], {"d": 10, "c": 200, "b": "cc", "a": "y"})
    self.assertEqual(a_is_bad, str(ctx.exception))

    # Mixed: a bad positional argument is reported before a bad keyword.
    with self.assertRaises(ParameterTypeError) as ctx:
        checker.verify_args(["y", "cc"], {"d": 10, "c": 200})
    self.assertEqual(a_is_bad, str(ctx.exception))
def test_validate_instance(self):
    """``Instance`` hints accept the class and its subclasses only."""
    class MyClass:
        pass

    class ChildClass(MyClass):
        pass

    class UnrelatedClass:
        pass

    def example_function(a: Instance(MyClass)) -> Instance(ChildClass):
        pass

    checker = Enforcer(example_function)
    base_obj = MyClass()
    child_obj = ChildClass()
    stranger = UnrelatedClass()

    # Argument hint is the base class, so both base and child are accepted.
    for accepted in (base_obj, child_obj):
        checker.verify_args([accepted], {})
    checker.verify_result(child_obj)

    with self.assertRaises(ParameterTypeError):
        checker.verify_args([stranger], {})
    # Return hint is the child class, so the base class is rejected too.
    for rejected in (base_obj, stranger):
        with self.assertRaises(ReturnTypeError):
            checker.verify_result(rejected)
def test_validate_with_none(self):
    """A ``None`` hint accepts only ``None``."""
    def example_function(a: None) -> None:
        pass

    checker = Enforcer(example_function)
    checker.verify_args([None], {})
    checker.verify_result(None)

    with self.assertRaises(ParameterTypeError) as arg_ctx:
        checker.verify_args([0], {})
    self.assertEqual(
        "The 'a' parameter of 'example_function' must be None, "
        "but a value of 0 <class 'int'> was specified.",
        str(arg_ctx.exception),
    )

    with self.assertRaises(ReturnTypeError) as ret_ctx:
        checker.verify_result(0)
    self.assertEqual(
        "The return type of 'example_function' must be None, "
        "but a value of 0 <class 'int'> was returned.",
        str(ret_ctx.exception),
    )
def test_validate_with_lists(self):
    """A ``list`` hint accepts lists (of any contents) and rejects tuples."""
    def example_function(a: list) -> list:
        pass

    checker = Enforcer(example_function)
    checker.verify_args([[1, 2, "a"]], {})
    checker.verify_result([])

    with self.assertRaises(ParameterTypeError) as arg_ctx:
        checker.verify_args([(1, 2, "a")], {})
    self.assertEqual(
        "The 'a' parameter of 'example_function' must be <class 'list'>, "
        "but a value of (1, 2, 'a') <class 'tuple'> was specified.",
        str(arg_ctx.exception),
    )

    with self.assertRaises(ReturnTypeError) as ret_ctx:
        checker.verify_result(tuple())
    self.assertEqual(
        "The return type of 'example_function' must be <class 'list'>, "
        "but a value of () <class 'tuple'> was returned.",
        str(ret_ctx.exception),
    )
def test_validate_either(self):
    """An ``Either`` hint accepts any listed type and rejects the rest."""
    def example_function(a: Either(Str, Int)) -> Either(Int, Str):
        pass

    checker = Enforcer(example_function)
    for good_arg in ("a", 1):
        checker.verify_args([good_arg], {})
    for good_result in (2, "b"):
        checker.verify_result(good_result)

    # A float matches neither Str nor Int.
    with self.assertRaises(ParameterTypeError):
        checker.verify_args([1.0], {})
    with self.assertRaises(ReturnTypeError):
        checker.verify_result(1.0)
def test_validate_enum(self):
    """An ``Enum`` hint accepts only the listed values."""
    def example_function(a: Enum(1, 5, "c", 4.0)) -> Enum(7, "d"):
        pass

    checker = Enforcer(example_function)
    for member in (1, 5, "c", 4.0):
        checker.verify_args([member], {})
    for member in (7, "d"):
        checker.verify_result(member)

    # A value from the *other* enum is not accepted.
    with self.assertRaises(ParameterTypeError):
        checker.verify_args([7], {})
    with self.assertRaises(ReturnTypeError):
        checker.verify_result("c")