def test_from_parse():
    """Parse a small affine function, manage it, compile it and run it."""

    def f(x):
        return 2 * x + 1

    g = parse(f)
    manage(g)
    print(print_graph(g))
    cf = compile_graph(g, debug=True)
    assert cf(0) == 1, cf(0)
    # The assertion message must report the value actually under test
    # (previously it evaluated cf(0), hiding the real result for -2).
    assert cf(-2) == -3, cf(-2)
def test_annotation_parsing_local_import():
    """Annotations naming a locally imported type are resolved by the parser.

    ``ArrayType`` is imported inside the test on purpose. It is not imported
    at module level, so annotation resolution has to consult the local scope.
    """
    from array import ArrayType

    def f(a: ArrayType) -> ArrayType:
        b: ArrayType = np.arange(10)
        return a.sum() + b.sum()

    graph = raw_parse(f)
    manager = manage(graph)

    # Parameter annotation.
    by_name = {param.debug.debug_name: param for param in graph.parameters}
    assert by_name["a"].annotation is ArrayType

    # Return annotation.
    assert graph.return_.annotation is ArrayType

    # Variable annotation: exactly one node called "b" must carry it.
    named_b = [
        node for node in manager.all_nodes if node.debug.debug_name == "b"
    ]
    for node in named_b:
        assert node.annotation is ArrayType
    assert len(named_b) == 1
def test_keep_roots_recursion():
    """keep_roots() collects mutually recursive graphs that lost all users."""

    def nonrec():
        return 123

    def rec1():
        return rec2()

    def rec2():
        return rec1()

    @parse
    def f(x, y):
        return rec1() + nonrec()

    manager = manage(f)
    assert len(manager.graphs) == 4

    # Detach f's body. nonrec is reclaimed immediately. rec1 and rec2
    # reference each other, so that cycle keeps both of them alive.
    manager.replace(f.output, f.parameters[0])
    assert len(manager.graphs) == 3

    # keep_roots() acts like a garbage collector and cuts the cycle.
    manager.keep_roots()
    assert len(manager.graphs) == 1
def test_annotation_parsing_typing():
    """Builtin and ``typing`` annotations are preserved on parsed nodes."""
    # The annotation for b is deliberately wrong; it is only used here to
    # check that annotations are parsed. The function is parsed, not run.
    def f(a: int, b: List[int]) -> bool:
        c: tuple = (2, 3)
        d: int = int(b + 1.5)
        return bool(a * b) * c[0] + d

    graph = raw_parse(f)
    manager = manage(graph)

    # Parameter annotations: the subscripted generic is kept as-is.
    by_name = {param.debug.debug_name: param for param in graph.parameters}
    assert by_name["a"].annotation is int
    assert by_name["b"].annotation is not List
    assert by_name["b"].annotation is List[int]

    # Return annotation.
    assert graph.return_.annotation is bool

    # Variable annotations: both c and d must be found exactly once in total.
    expected = {"c": tuple, "d": int}
    found = 0
    for node in manager.all_nodes:
        wanted = expected.get(node.debug.debug_name)
        if wanted is None:
            continue
        assert node.annotation is wanted
        found += 1
    assert found == 2
def test_drop_root():
    """_maybe_drop_graphs() leaves the graph given to manage() in place."""

    @parse
    def f(x, y):
        return x * y

    manager = manage(f)
    assert f in manager.nodes
    manager._maybe_drop_graphs({f})
    # f was handed to manage() directly, so the drop attempt must not remove it.
    assert f in manager.nodes
def test_cannot_replace_return():
    """Replacing a graph's return node is rejected with ManagerError."""

    @parse
    def f(x):
        return x * x

    manager = manage(f)
    return_node = f.return_
    first_param = f.parameters[0]
    with pytest.raises(ManagerError):
        manager.replace(return_node, first_param)
def test_set_output():
    """Assigning Graph.output updates the graph's managed node set."""

    @parse
    def f(x, y):
        return x * y

    manager = manage(f)
    assert f.manager is manager
    assert len(f.nodes) == 4

    # Short-circuit the body so the output is the first parameter.
    f.output = f.parameters[0]
    # One node (presumably the x * y application) is no longer reachable.
    assert len(f.nodes) == 3
def test_manager_exclusivity():
    """A graph already owned by a manager cannot join a second manager."""

    @parse
    def f(x):
        return x * x

    owner = manage(f)
    assert f._manager is owner
    with pytest.raises(ManagerError):
        GraphManager(f)
def _import_graph(self, graph):
    """Clone ``graph`` together with its chain of parents into this object.

    All clones are added to ``self.graphs``. The clone of ``graph`` itself
    is also added to ``self.focus``.
    """
    # Keep the weak manager referenced for the whole method.
    manager = manage(graph, weak=True)
    parent_of = manager.parents

    # Walk up the parent chain until parent_of[...] yields a falsy value.
    lineage = set()
    current = graph
    while current:
        lineage.add(current)
        current = parent_of[current]

    cloner = GraphCloner(*lineage, total=True, relation='cosmetic')
    self.graphs |= {cloner[member] for member in lineage}
    self.focus.add(cloner[graph])
def test_add_parameter():
    """Graph.add_parameter() registers the new parameter with the manager."""

    @parse
    def f(x, y):
        return x * y

    manager = manage(f)

    def check_params(expected_count):
        # Parameter count matches, and every parameter is tracked by manager.
        assert len(f.parameters) == expected_count
        assert set(f.parameters).issubset(manager.nodes[f])

    check_params(2)
    f.add_parameter()
    check_params(3)
def check_no_free_variables(root):
    """Ensure no graph reachable from ``root`` is nested or uses free variables.

    Raises:
        Exception: if a managed graph has a non-None ``parent``, or if a node
            has an input belonging to a graph other than its own.
    """
    manager = manage(root)
    for graph, nodes in manager.nodes.items():
        if not graph:
            continue  # falsy keys carry no graph to check
        if graph.parent is not None:
            raise Exception(f"Nested graph detected: {graph}")
        for node in nodes:
            assert node.graph is graph
            uses_foreign_input = any(
                inp.graph is not None and inp.graph is not graph
                for inp in node.inputs
            )
            if uses_foreign_input:
                raise Exception(f"Free variable detected: {node}")
def run(self, graph, backend):
    """Translate ``graph`` into Python source and return the compiled entry.

    Every graph kept by the manager is emitted as a Python function. The
    entry graph is always emitted as ``main``.

    :type backend: PythonBackend
    :return: the generated ``main`` function, or a ``PdbRunCall`` wrapping
        the source code when ``backend.pdb`` is set.
    """
    manager = manage(graph)
    manager.keep_roots(graph)

    # Assign function names: the entry graph is "main", others get a label.
    for g in manager.graphs:
        self.graph_to_name[g] = "main" if g is graph else self.get_label(g)

    # Translate each named graph into (parameter names, body code).
    for g, name in self.graph_to_name.items():
        self.fn_name_to_code[name] = self.convert_func(g)

    header = [
        "import math",
        "import operator",
        "import numpy as np",
        "from myia.utils import RandomStateWrapper",
        "from myia.lib import TaggedValue",
        "from myia.utils.universe import HandleInstance",
        "import myia.compile.backends.python.implementations as IMPL",
    ]
    helpers = []
    main_signature = main_body = None
    for name, (params, body) in self.fn_name_to_code.items():
        signature = f"def {name}({', '.join(params)}):"
        if name == "main":
            main_signature, main_body = signature, body
        else:
            helpers += [signature, body]

    # main goes last so that every helper it calls is already defined.
    source = nested_list_to_code_string(
        header + helpers + [main_signature, main_body]
    )

    if backend.debug:
        backend.debug.write(f"\n{source}")
    if backend.pdb:
        return PdbRunCall(source)

    # Execute the generated source in a fresh module and fetch its entry.
    # reference: https://stackoverflow.com/a/19850183
    code_object = compile(source, "", "exec")
    module = ModuleType("mod")
    exec(code_object, module.__dict__)
    return module.main
def run(self, graph, context, target, exec_kind):
    """Convert ``graph`` into a relay module, evaluate it and return the result.

    Top-level graphs become relay global functions. The entry graph is
    emitted as ``main``, and other top-level graphs are named
    ``"_" + debug_name``.
    """
    manager = manage(graph)
    graph, handles_params = return_handles(graph)
    manager.keep_roots(graph)

    self.module = tvm.IRModule({})
    self.types = TypeHelper()
    self.types.initialize(self.module, manager)
    self.make_const = RelayConstantConverter(context, self.types)
    self.universe_helper = None
    self.i = 0

    # NOTE(review): the earlier comment here mentioned building a global
    # union type for all union values; presumably that happens inside
    # TypeHelper.initialize — confirm.
    self.node_map = {}
    self.graph_map = {}
    for g in manager.graphs:
        if g.parent is not None:
            continue  # only top-level graphs become relay globals
        if g is graph:
            self.graph_map[g] = relay.GlobalVar("main")
        else:
            # Mangle user names so they cannot clash with "main".
            self.graph_map[g] = relay.GlobalVar("_" + g.debug.debug_name)

    function_map = {}
    for g in self.graph_map:
        converted = self.convert_func(g)
        function_map[self.graph_map[g]] = converted
    add_functions(self.module, function_map)

    executor = relay.create_executor(
        mod=self.module, ctx=context, target=target, kind=exec_kind
    )
    result = executor.evaluate()
    fill_reverse_tag_map()
    return handle_wrapper(result, handles_params)
def make_handle_to_make_cell(g):
    """Replace uset(*make_handle(typ), value) by make_cell(value, U).

    RefCreate both creates the reference and sets it. Matching
    ``universe_setitem`` nodes whose handle comes from ``make_handle`` are
    therefore fused into a single ``make_cell``. Its second argument is the
    ``make_handle`` node's ``inputs[2]`` (presumably the universe ``U``).
    """
    manager = manage(g)
    pattern = (
        P.universe_setitem,
        (P.tuple_getitem, X, 0),
        (P.tuple_getitem, X, 1),
        Y,
    )
    # Iterate over a snapshot, because the replacements below mutate the graph.
    for node in list(manager.all_nodes):
        equiv = node.match(pattern)
        if not equiv:
            continue
        handle_source = equiv[X]
        if not handle_source.is_apply(P.make_handle):
            continue
        cell = sexp_to_node(
            (make_cell, equiv[Y], handle_source.inputs[2]), node.graph)
        manager.replace(handle_source, cell)
        # Replace the setitem with its first tuple_getitem argument, which
        # now reads from the new cell node.
        manager.replace(node, node.inputs[1])
def test_keep_roots():
    """keep_roots() retains only the requested root graphs."""

    @clone
    @parse
    def f(x, y):
        return x * y

    @clone
    @parse
    def g(x, y):
        return x + y

    manager = manage(f)
    assert manager.graphs == OrderedSet([f])

    manager.add_graph(g)
    assert manager.graphs == OrderedSet([f, g])

    # Without arguments, only the original root f is kept.
    manager.keep_roots()
    assert manager.graphs == OrderedSet([f])

    # With an explicit root, that root replaces the previous ones.
    manager.keep_roots(g)
    assert manager.graphs == OrderedSet([g])
def test_annotation_parsing_numpy():
    """``np.ndarray`` annotations are preserved on parsed nodes."""

    def f(a: np.ndarray) -> np.ndarray:
        b: np.ndarray = np.arange(10)
        return a.sum() + b.sum()

    graph = raw_parse(f)
    manager = manage(graph)

    # Parameter annotation.
    by_name = {param.debug.debug_name: param for param in graph.parameters}
    assert by_name["a"].annotation is np.ndarray

    # Return annotation.
    assert graph.return_.annotation is np.ndarray

    # Variable annotation: exactly one node called "b" must carry it.
    named_b = [
        node for node in manager.all_nodes if node.debug.debug_name == "b"
    ]
    for node in named_b:
        assert node.annotation is np.ndarray
    assert len(named_b) == 1
def compile(self, graph, *others):
    """Compile a graph.

    ``graph`` is managed and closure-converted, then passed to
    ``self.compiler.compile_and_link``. Extra positional arguments
    (``others``) are accepted but not used.
    """
    # NOTE(review): manage() presumably must run before closure_convert
    # so the graph has a manager attached — confirm.
    manage(graph)
    converted = closure_convert(graph)
    return self.compiler.compile_and_link(converted)