Esempio n. 1
0
def test_from_parse():
    """Parse a simple affine function, compile it, and check its outputs."""

    def f(x):
        return 2 * x + 1

    g = parse(f)
    manage(g)

    print(print_graph(g))

    cf = compile_graph(g, debug=True)
    assert cf(0) == 1, cf(0)
    # The failure message must report the value actually under test
    # (previously it printed cf(0), which is misleading when this fails).
    assert cf(-2) == -3, cf(-2)
Esempio n. 2
0
def test_annotation_parsing_local_import():
    """Annotations that name a locally imported object are resolved.

    ``ArrayType`` is deliberately imported inside this test rather than at
    module level, so parsing has to resolve it from a local import.
    """
    from array import ArrayType

    def f(a: ArrayType) -> ArrayType:
        b: ArrayType = np.arange(10)
        return a.sum() + b.sum()

    graph = raw_parse(f)
    mng = manage(graph)

    # Parameter annotation, looked up by the parameter's debug name.
    by_name = {param.debug.debug_name: param for param in graph.parameters}
    assert by_name["a"].annotation is ArrayType

    # Return annotation.
    assert graph.return_.annotation is ArrayType

    # Exactly one node carries the debug name "b", annotated with ArrayType.
    named_b = [node for node in mng.all_nodes if node.debug.debug_name == "b"]
    for node in named_b:
        assert node.annotation is ArrayType
    assert len(named_b) == 1
Esempio n. 3
0
def test_keep_roots_recursion():
    """keep_roots() reclaims mutually recursive graphs that replace() misses."""

    def nonrec():
        return 123

    def rec1():
        return rec2()

    def rec2():
        return rec1()

    @parse
    def f(x, y):
        return rec1() + nonrec()

    manager = manage(f)
    # Four graphs: presumably f, nonrec, rec1 and rec2.
    assert len(manager.graphs) == 4

    # Cutting f's output frees nonrec. rec1 and rec2 still reference each
    # other, so reference-based reclamation keeps both of them alive.
    manager.replace(f.output, f.parameters[0])
    assert len(manager.graphs) == 3

    # keep_roots() behaves like a garbage collector and breaks the cycle.
    manager.keep_roots()
    assert len(manager.graphs) == 1
Esempio n. 4
0
def test_annotation_parsing_typing():
    """Annotations built with ``typing`` generics are preserved verbatim."""

    # The annotation on ``b`` is deliberately wrong (``b`` is used as a
    # number below); it only serves to test parsing of ``List[int]``.
    def f(a: int, b: List[int]) -> bool:
        c: tuple = (2, 3)
        d: int = int(b + 1.5)
        return bool(a * b) * c[0] + d

    graph = raw_parse(f)
    mng = manage(graph)

    # Parameter annotations: the parameterized generic, not the bare List.
    by_name = {param.debug.debug_name: param for param in graph.parameters}
    assert by_name["a"].annotation is int
    assert by_name["b"].annotation is not List
    assert by_name["b"].annotation is List[int]

    # Return annotation.
    assert graph.return_.annotation is bool

    # Local-variable annotations, keyed by the variable's debug name.
    expected = {"c": tuple, "d": int}
    seen = 0
    for node in mng.all_nodes:
        wanted = expected.get(node.debug.debug_name)
        if wanted is not None:
            assert node.annotation is wanted
            seen += 1
    assert seen == 2
Esempio n. 5
0
def test_drop_root():
    """_maybe_drop_graphs() leaves the managed root graph in place."""

    @parse
    def f(x, y):
        return x * y

    manager = manage(f)
    assert f in manager.nodes

    # f was passed to manage(), so it must survive a drop request.
    manager._maybe_drop_graphs({f})
    assert f in manager.nodes
Esempio n. 6
0
def test_cannot_replace_return():
    """The manager refuses to replace a graph's return node."""

    @parse
    def f(x):
        return x * x

    manager = manage(f)

    # The attempted replacement must raise ManagerError.
    with pytest.raises(ManagerError):
        manager.replace(f.return_, f.parameters[0])
Esempio n. 7
0
def test_set_output():
    """Assigning ``f.output`` on a managed graph updates its node set."""

    @parse
    def f(x, y):
        return x * y

    manager = manage(f)
    assert f.manager is manager
    assert len(f.nodes) == 4

    # Route the output straight to the first parameter. Expect one node
    # fewer, presumably the now-unused multiplication.
    first_param = f.parameters[0]
    f.output = first_param
    assert len(f.nodes) == 3
Esempio n. 8
0
def test_manager_exclusivity():
    """A graph owned by one manager cannot be claimed by a second one."""

    @parse
    def f(x):
        return x * x

    owner = manage(f)
    assert f._manager is owner

    # Constructing a fresh GraphManager over f must be rejected.
    with pytest.raises(ManagerError):
        GraphManager(f)
Esempio n. 9
0
 def _import_graph(self, graph):
     """Clone ``graph`` and its chain of parent graphs into this object.

     The clones (made with ``total=True, relation='cosmetic'``) are added to
     ``self.graphs``; the clone of ``graph`` itself is added to ``self.focus``.
     """
     # Keep a reference to the (weak) manager for the whole method.
     manager = manage(graph, weak=True)
     lineage = set()
     parent_of = manager.parents
     # Climb from graph through its ancestors until the parent lookup is falsy.
     current = graph
     while current:
         lineage.add(current)
         current = parent_of[current]
     cloner = GraphCloner(*lineage, total=True, relation='cosmetic')
     self.graphs |= {cloner[g] for g in lineage}
     self.focus.add(cloner[graph])
Esempio n. 10
0
def test_add_parameter():
    """Parameters added after manage() are tracked by the manager."""

    @parse
    def f(x, y):
        return x * y

    manager = manage(f)

    def all_params_tracked():
        # Every current parameter must belong to the manager's nodes for f.
        return set(f.parameters).issubset(manager.nodes[f])

    assert len(f.parameters) == 2
    assert all_params_tracked()

    f.add_parameter()
    assert len(f.parameters) == 3
    assert all_params_tracked()
Esempio n. 11
0
def check_no_free_variables(root):
    """Verify that the graphs managed from ``root`` are flat and closed.

    No managed graph may have a parent graph, and no node may take an input
    that belongs to a graph other than its own.

    Raises:
        Exception: If a nested graph or a free variable is found.
        AssertionError: If a node listed under a graph does not report that
            graph as its ``graph`` (checked only when assertions are enabled).
    """
    manager = manage(root)
    for owner, owned_nodes in manager.nodes.items():
        # Falsy keys are skipped (presumably a bucket for graph-less nodes).
        if not owner:
            continue
        if g_parent_is_set(owner):
            raise Exception(f"Nested graph detected: {owner}")
        for node in owned_nodes:
            assert node.graph is owner
            # An input from any other graph is a free variable.
            if any(
                inp.graph is not None and inp.graph is not owner
                for inp in node.inputs
            ):
                raise Exception(f"Free variable detected: {node}")


def g_parent_is_set(graph):
    """Return True when ``graph`` is nested inside another graph."""
    return graph.parent is not None
Esempio n. 12
0
    def run(self, graph, backend):
        """Compile given graph.

        After ``keep_roots(graph)``, every managed graph is converted to a
        Python function, and ``graph`` itself becomes ``main``. The functions
        are assembled into one source string. When ``backend.pdb`` is set,
        that string is handed to ``PdbRunCall``. Otherwise it is executed in a
        fresh module and that module's ``main`` function is returned.

        Note: ``self.graph_to_name`` and ``self.fn_name_to_code`` are updated
        in place and are not cleared by this method.

        :type backend: PythonBackend
        """
        manager = manage(graph)
        manager.keep_roots(graph)

        # Name every graph; the entry graph is always called "main".
        for g in manager.graphs:
            self.graph_to_name[g] = "main" if g is graph else self.get_label(g)

        # Translate each named graph into a (parameter names, body) pair.
        for g, name in self.graph_to_name.items():
            self.fn_name_to_code[name] = self.convert_func(g)

        imports = [
            "import math",
            "import operator",
            "import numpy as np",
            "from myia.utils import RandomStateWrapper",
            "from myia.lib import TaggedValue",
            "from myia.utils.universe import HandleInstance",
            "import myia.compile.backends.python.implementations as IMPL",
        ]
        helper_defs = []
        main_signature = main_body = None
        for name, (params, body) in self.fn_name_to_code.items():
            signature = f"def {name}({', '.join(params)}):"
            if name == "main":
                main_signature, main_body = signature, body
            else:
                helper_defs += [signature, body]

        # Imports, then helper functions, then "main" last.
        code = nested_list_to_code_string(
            imports + helper_defs + [main_signature, main_body]
        )

        if backend.debug:
            backend.debug.write(f"\n{code}")

        if backend.pdb:
            return PdbRunCall(code)

        # Execute the generated source in a throwaway module namespace and
        # return its entry point (see https://stackoverflow.com/a/19850183).
        module = ModuleType("mod")
        exec(compile(code, "", "exec"), module.__dict__)
        return module.main
Esempio n. 13
0
    def run(self, graph, context, target, exec_kind):
        """Convert the graph into a relay callable.

        Each parentless managed graph becomes a relay global function. The
        graph returned by ``return_handles`` becomes ``main``. The resulting
        module is evaluated with a relay executor of kind ``exec_kind``, and
        the result is passed through ``handle_wrapper`` together with the
        handle parameters.
        """
        manager = manage(graph)

        graph, handles_params = return_handles(graph)
        manager.keep_roots(graph)

        # Fresh per-run conversion state.
        self.module = tvm.IRModule({})
        self.types = TypeHelper()
        self.types.initialize(self.module, manager)
        self.make_const = RelayConstantConverter(context, self.types)
        self.universe_helper = None
        self.i = 0

        self.node_map = {}
        self.graph_map = {}

        # Only top-level graphs become globals. User-derived names are
        # mangled with a leading "_" (presumably so they cannot clash with
        # "main").
        for g in manager.graphs:
            if g.parent is not None:
                continue
            if g is graph:
                self.graph_map[g] = relay.GlobalVar("main")
            else:
                self.graph_map[g] = relay.GlobalVar("_" + g.debug.debug_name)

        function_map = {
            global_var: self.convert_func(g)
            for g, global_var in self.graph_map.items()
        }
        add_functions(self.module, function_map)

        executor = relay.create_executor(
            mod=self.module, ctx=context, target=target, kind=exec_kind
        )
        result = executor.evaluate()

        fill_reverse_tag_map()

        return handle_wrapper(result, handles_params)
Esempio n. 14
0
def make_handle_to_make_cell(g):
    """Replace uset(*make_handle(typ), value) by make_cell(value, U).

    This is because RefCreate both creates the reference and sets it.

    The rewrite is done in place through the manager of ``g``, so it covers
    every node the manager tracks. The function returns None.

    For each node matching
    ``universe_setitem(tuple_getitem(X, 0), tuple_getitem(X, 1), Y)`` where
    ``X`` is a ``make_handle`` application:

    - ``X`` is replaced by ``make_cell(Y, X.inputs[2])``;
    - the ``universe_setitem`` node is replaced by its first argument.
    """
    mng = manage(g)
    # Iterate over a snapshot, because mng.replace mutates the node set below.
    for node in list(mng.all_nodes):
        # Match uset(tuple_getitem(X, 0), tuple_getitem(X, 1), Y).
        # NOTE(review): presumably the matcher requires both occurrences of X
        # to bind the same node; confirm against the pattern-matching code.
        equiv = node.match((
            P.universe_setitem,
            (P.tuple_getitem, X, 0),
            (P.tuple_getitem, X, 1),
            Y,
        ))
        if equiv:
            x = equiv[X]
            # Only rewrite when X is an application of make_handle.
            if x.is_apply(P.make_handle):
                # x.inputs[2] is used as the universe argument U of make_cell,
                # per the docstring. This presumes inputs[0] is the applied
                # function; NOTE(review): confirm make_handle's argument order.
                new_handle_node = sexp_to_node(
                    (make_cell, equiv[Y], x.inputs[2]), node.graph)
                # Redirect all uses of the make_handle result to make_cell.
                mng.replace(x, new_handle_node)
                # Drop the setitem: its users now receive its first argument,
                # i.e. the tuple_getitem(..., 0) node (now reading from the
                # make_cell result, assuming replace rewired its input).
                mng.replace(node, node.inputs[1])
Esempio n. 15
0
def test_keep_roots():
    """keep_roots() prunes every graph that is not reachable from the roots."""

    @clone
    @parse
    def f(x, y):
        return x * y

    @clone
    @parse
    def g(x, y):
        return x + y

    manager = manage(f)
    assert manager.graphs == OrderedSet([f])

    manager.add_graph(g)
    assert manager.graphs == OrderedSet([f, g])

    # With no argument, g (added via add_graph, not as a root) is dropped.
    manager.keep_roots()
    assert manager.graphs == OrderedSet([f])

    # With an explicit root, only that graph survives.
    manager.keep_roots(g)
    assert manager.graphs == OrderedSet([g])
Esempio n. 16
0
def test_annotation_parsing_numpy():
    """NumPy types used as annotations are preserved by the parser."""

    def f(a: np.ndarray) -> np.ndarray:
        b: np.ndarray = np.arange(10)
        return a.sum() + b.sum()

    graph = raw_parse(f)
    mng = manage(graph)

    # Parameter annotation, looked up by the parameter's debug name.
    by_name = {param.debug.debug_name: param for param in graph.parameters}
    assert by_name["a"].annotation is np.ndarray

    # Return annotation.
    assert graph.return_.annotation is np.ndarray

    # Exactly one node carries the debug name "b", annotated with ndarray.
    named_b = [node for node in mng.all_nodes if node.debug.debug_name == "b"]
    for node in named_b:
        assert node.annotation is np.ndarray
    assert len(named_b) == 1
Esempio n. 17
0
 def compile(self, graph, *others):
     """Compile a graph.

     ``graph`` is put under a manager, closure-converted, and then passed
     to ``self.compiler.compile_and_link``. Extra positional arguments
     (``others``) are accepted but ignored.
     """
     manage(graph)
     converted = closure_convert(graph)
     return self.compiler.compile_and_link(converted)