def test_annotation_parsing_local_import():
    """Check that raw_parse records annotations naming a locally imported type.

    ``ArrayType`` is imported inside this test rather than at module level,
    to exercise parsing of functions whose annotations come from a local
    import.
    """
    from array import ArrayType

    def f(a: ArrayType) -> ArrayType:
        b: ArrayType = np.arange(10)
        return a.sum() + b.sum()

    graph = raw_parse(f)
    manager = manage(graph)

    # Parameter annotation.
    by_name = {param.debug.debug_name: param for param in graph.parameters}
    assert by_name["a"].annotation is ArrayType

    # Return annotation.
    assert graph.return_.annotation is ArrayType

    # Variable annotation: every node named "b" must carry ArrayType, and
    # exactly one such node must exist.
    b_nodes = [
        node for node in manager.all_nodes if node.debug.debug_name == "b"
    ]
    for node in b_nodes:
        assert node.annotation is ArrayType
    assert len(b_nodes) == 1
def test_annotation_parsing_typing():
    """Check that raw_parse records builtin and ``typing`` annotations."""
    # The annotation for b is deliberately wrong (b is used as a number
    # below); we use it here only to exercise a subscripted typing generic.
    def f(a: int, b: List[int]) -> bool:
        c: tuple = (2, 3)
        d: int = int(b + 1.5)
        return bool(a * b) * c[0] + d

    graph = raw_parse(f)
    manager = manage(graph)

    # Check parameters annotation.
    parameters = {p.debug.debug_name: p for p in graph.parameters}
    assert parameters["a"].annotation is int
    # The subscript must be preserved: List[int], not the bare List.
    assert parameters["b"].annotation is not List
    # Use == rather than `is`: two separately evaluated List[int] expressions
    # are only the same object thanks to typing's internal (bounded) cache,
    # an implementation detail. Generic aliases define equality explicitly.
    assert parameters["b"].annotation == List[int]

    # Check return annotation.
    assert graph.return_.annotation is bool

    # Check variable annotations: each annotated local must be found with
    # its declared annotation.
    expected = {"c": tuple, "d": int}
    variables_checked = 0
    for node in manager.all_nodes:
        name = node.debug.debug_name
        if name in expected:
            assert node.annotation is expected[name]
            variables_checked += 1
    assert variables_checked == len(expected)
def test_parametric():
    """Check the argument-kind metadata raw_parse attaches to parsed graphs."""
    def f(x, y=6):
        return x + y

    def g(x, y, *args):
        return x + args[0]

    def h(*args):
        return f(*args) * g(*args)

    def i(x, *, y):
        return x + y

    def j(**kwargs):
        return kwargs

    # (function, graph attribute, expected value), checked in order:
    #   defaults -> names of parameters with a default value (f has y=6)
    #   vararg   -> name of the *args parameter
    #   kwonly   -> count of keyword-only parameters (i has one: y)
    #   kwarg    -> name of the **kwargs parameter
    cases = [
        (f, "defaults", ['y']),
        (g, "vararg", 'args'),
        (h, "vararg", 'args'),
        (i, "kwonly", 1),
        (j, "kwarg", 'kwargs'),
    ]
    for fn, attr, want in cases:
        assert getattr(raw_parse(fn), attr) == want
def test_annotation_parsing_numpy():
    """Check that raw_parse records ``np.ndarray`` annotations everywhere."""
    def f(a: np.ndarray) -> np.ndarray:
        b: np.ndarray = np.arange(10)
        return a.sum() + b.sum()

    graph = raw_parse(f)
    manager = manage(graph)

    # Parameter annotation.
    params_by_name = {}
    for param in graph.parameters:
        params_by_name[param.debug.debug_name] = param
    assert params_by_name["a"].annotation is np.ndarray

    # Return annotation.
    assert graph.return_.annotation is np.ndarray

    # Variable annotation: exactly one node named "b", annotated np.ndarray.
    b_annotations = [
        node.annotation
        for node in manager.all_nodes
        if node.debug.debug_name == "b"
    ]
    assert all(ann is np.ndarray for ann in b_annotations)
    assert len(b_annotations) == 1
def test_fn_param_same_name():
    """A parameter that shadows its own function's name must win inside the body."""
    def a(a):
        return a + 1

    parsed = raw_parse(a)
    first_param = parsed.parameters[0]
    # Input 0 of the output node is presumably the `+` operator; input 1
    # must be the parameter `a`, not a reference to the function `a`.
    assert parsed.output.inputs[1] is first_param